This commit is contained in:
ton
2023-09-10 11:28:40 +07:00
parent fb3e59a414
commit e7c0228313
5 changed files with 384 additions and 155 deletions

View File

@@ -414,17 +414,18 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
bestAccuracy = 0.0
finalAnswer = [0] |> device # store model prediction in (logit of choices, batch)
stop = 0
vt0 = 0.0 # store vt to compute learning progress
for epoch = 1:1000
stop == 3 ? break : false
println("epoch $epoch")
n = length(trainData)
println("n $n")
p = Progress(n, dt=1.0) # minimum update interval: 1 second
for (imgBatch, labels) in trainData # imgBatch (28, 28, 4) i.e. (row, col, batch)
for (imgBatch, labels) in trainData # imgBatch(28, 28, 4) i.e. (row, col, batch), labels(label, batch)
for rep in 1:10
stop == 3 ? break : false
#WORKING prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
# prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.1), copies=18)
if length(size(signal)) == 3
row, col, sequence = size(signal)
@@ -434,7 +435,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
end
# encode labels
correctAnswer = onehotbatch(labels, labelDict) # (choices, batch)
correctAnswer = onehotbatch(labels, labelDict) # (correctAnswer, batch)
# insert data into model sequencially
for timestep in 1:(sequence + thinkingPeriod) # sMNIST has 784 timestep(pixel) + thinking period = 1000 timestep
@@ -447,6 +448,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
if timestep == 1 # tell a model to start learning. 1-time only
model.learningStage = [1]
finalAnswer = [0] |> device
vt0 = 0.0
elseif timestep == (sequence+thinkingPeriod)
model.learningStage = [3]
else
@@ -467,9 +469,20 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
# no error calculation
elseif timestep == sequence # online learning, 1-by-1 timestep
# no error calculation
elseif timestep > sequence && timestep < sequence+thinkingPeriod # collect answer
#WORKING answer time windows, collect logit to get finalAnswer
elseif timestep > sequence && timestep < sequence+thinkingPeriod
logit_cpu = logit |> cpu
logit_cpu = logit_cpu[:,1]
finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit # (logit, batch)
predict_cpu = logit |> cpu
finalAnswer_cpu = finalAnswer |> cpu
on_vt_cpu = model.on_vt |> cpu
on_vt_cpu = on_vt_cpu[1,1,:,1]
modelError = loss(vt0, on_vt_cpu, logit_cpu, finalAnswer_cpu, labels[1])
vt0 = on_vt_cpu # update vt0 for this timestep
error("DEBUG -> main $(Dates.now())")
modelError = (predict_cpu .- correctAnswer)
modelError = reshape(modelError, (1,1,:, size(modelError, 2)))
@@ -509,14 +522,28 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
# # error("DEBUG -> main $(Dates.now())")
# end
elseif timestep == sequence+thinkingPeriod
elseif timestep == sequence+thinkingPeriod #TODO update code
logit_cpu = logit |> cpu
finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit # (logit, batch)
predict_cpu = logit |> cpu
finalAnswer_cpu = finalAnswer |> cpu
on_vt_cpu = model.on_vt |> cpu
on_vt_cpu = on_vt_cpu[1,1,:,1]
# get vt of correct neuron, julia array is 1-based index
labelPosition = labels[1] + 1
on_vt_cpu = on_vt_cpu[labelPosition]
modelError = (predict_cpu .- correctAnswer)
modelError = loss(vt0, on_vt_cpu, logit_cpu, finalAnswer_cpu)
vt0 = on_vt_cpu
error("DEBUG -> main $(Dates.now())")
modelError = (logit_cpu .- correctAnswer)
modelError = reshape(modelError, (1,1,:, size(modelError, 2)))
modelError = sum(modelError, dims=3) |> device
outputError = (predict_cpu .- correctAnswer) |> device
outputError = (logit_cpu .- correctAnswer) |> device
lif_epsilonRec_cpu = model.lif_epsilonRec |> cpu
on_zt_cpu = model.on_zt |> cpu
@@ -847,6 +874,33 @@ function noiseGenerator(row, col, z; prob=0.5)
return noise
end
function loss(vt0::AbstractArray, vt1::AbstractArray, logit::AbstractArray,
finalAnswer::AbstractArray, correctAnswer::Number)
labelPosition = correctAnswer + 1 # julia array is 1-based index
# get vt of correct neuron
vt1 = vt1[labelPosition]
# get zt of correct neuron
zt = logit[labelPosition]
modelError = nothing
if zt == 1
modelError = 0.0 # already correct, no weight update
elseif vt1 > vt0 # progress increase
modelError = 1.0 - vt1
elseif vt1 == vt0 # no progress
modelError = 0.11111111 # special signal
elseif vt1 < vt0 # setback
modelError = vt0 - vt1
else
error("undefined condition line $(@__LINE__)")
end
return modelError
end
# function arrayMax(x)
# if sum(GeneralUtils.isNotEqual.(x, 0)) == 0 # guard against all-zeros array
# return GeneralUtils.isNotEqual.(x, 0)