dev
This commit is contained in:
@@ -414,17 +414,18 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
bestAccuracy = 0.0
|
||||
finalAnswer = [0] |> device # store model prediction in (logit of choices, batch)
|
||||
stop = 0
|
||||
vt0 = 0.0 # store vt to compute learning progress
|
||||
for epoch = 1:1000
|
||||
stop == 3 ? break : false
|
||||
println("epoch $epoch")
|
||||
n = length(trainData)
|
||||
println("n $n")
|
||||
p = Progress(n, dt=1.0) # minimum update interval: 1 second
|
||||
for (imgBatch, labels) in trainData # imgBatch (28, 28, 4) i.e. (row, col, batch)
|
||||
for (imgBatch, labels) in trainData # imgBatch(28, 28, 4) i.e. (row, col, batch), labels(label, batch)
|
||||
for rep in 1:10
|
||||
stop == 3 ? break : false
|
||||
|
||||
#WORKING prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
|
||||
# prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
|
||||
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.1), copies=18)
|
||||
if length(size(signal)) == 3
|
||||
row, col, sequence = size(signal)
|
||||
@@ -434,7 +435,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
end
|
||||
|
||||
# encode labels
|
||||
correctAnswer = onehotbatch(labels, labelDict) # (choices, batch)
|
||||
correctAnswer = onehotbatch(labels, labelDict) # (correctAnswer, batch)
|
||||
|
||||
# insert data into model sequencially
|
||||
for timestep in 1:(sequence + thinkingPeriod) # sMNIST has 784 timestep(pixel) + thinking period = 1000 timestep
|
||||
@@ -447,6 +448,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
if timestep == 1 # tell a model to start learning. 1-time only
|
||||
model.learningStage = [1]
|
||||
finalAnswer = [0] |> device
|
||||
vt0 = 0.0
|
||||
elseif timestep == (sequence+thinkingPeriod)
|
||||
model.learningStage = [3]
|
||||
else
|
||||
@@ -467,9 +469,20 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
# no error calculation
|
||||
elseif timestep == sequence # online learning, 1-by-1 timestep
|
||||
# no error calculation
|
||||
elseif timestep > sequence && timestep < sequence+thinkingPeriod # collect answer
|
||||
|
||||
#WORKING answer time windows, collect logit to get finalAnswer
|
||||
elseif timestep > sequence && timestep < sequence+thinkingPeriod
|
||||
logit_cpu = logit |> cpu
|
||||
logit_cpu = logit_cpu[:,1]
|
||||
finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit # (logit, batch)
|
||||
predict_cpu = logit |> cpu
|
||||
finalAnswer_cpu = finalAnswer |> cpu
|
||||
on_vt_cpu = model.on_vt |> cpu
|
||||
on_vt_cpu = on_vt_cpu[1,1,:,1]
|
||||
|
||||
modelError = loss(vt0, on_vt_cpu, logit_cpu, finalAnswer_cpu, labels[1])
|
||||
vt0 = on_vt_cpu # update vt0 for this timestep
|
||||
|
||||
error("DEBUG -> main $(Dates.now())")
|
||||
|
||||
modelError = (predict_cpu .- correctAnswer)
|
||||
modelError = reshape(modelError, (1,1,:, size(modelError, 2)))
|
||||
@@ -509,14 +522,28 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
# # error("DEBUG -> main $(Dates.now())")
|
||||
# end
|
||||
|
||||
elseif timestep == sequence+thinkingPeriod
|
||||
elseif timestep == sequence+thinkingPeriod #TODO update code
|
||||
logit_cpu = logit |> cpu
|
||||
finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit # (logit, batch)
|
||||
predict_cpu = logit |> cpu
|
||||
finalAnswer_cpu = finalAnswer |> cpu
|
||||
on_vt_cpu = model.on_vt |> cpu
|
||||
on_vt_cpu = on_vt_cpu[1,1,:,1]
|
||||
|
||||
# get vt of correct neuron, julia array is 1-based index
|
||||
labelPosition = labels[1] + 1
|
||||
on_vt_cpu = on_vt_cpu[labelPosition]
|
||||
|
||||
modelError = (predict_cpu .- correctAnswer)
|
||||
modelError = loss(vt0, on_vt_cpu, logit_cpu, finalAnswer_cpu)
|
||||
vt0 = on_vt_cpu
|
||||
|
||||
error("DEBUG -> main $(Dates.now())")
|
||||
|
||||
|
||||
|
||||
modelError = (logit_cpu .- correctAnswer)
|
||||
modelError = reshape(modelError, (1,1,:, size(modelError, 2)))
|
||||
modelError = sum(modelError, dims=3) |> device
|
||||
outputError = (predict_cpu .- correctAnswer) |> device
|
||||
outputError = (logit_cpu .- correctAnswer) |> device
|
||||
|
||||
lif_epsilonRec_cpu = model.lif_epsilonRec |> cpu
|
||||
on_zt_cpu = model.on_zt |> cpu
|
||||
@@ -847,6 +874,33 @@ function noiseGenerator(row, col, z; prob=0.5)
|
||||
return noise
|
||||
end
|
||||
|
||||
function loss(vt0::AbstractArray, vt1::AbstractArray, logit::AbstractArray,
|
||||
finalAnswer::AbstractArray, correctAnswer::Number)
|
||||
labelPosition = correctAnswer + 1 # julia array is 1-based index
|
||||
|
||||
# get vt of correct neuron
|
||||
vt1 = vt1[labelPosition]
|
||||
|
||||
# get zt of correct neuron
|
||||
zt = logit[labelPosition]
|
||||
|
||||
modelError = nothing
|
||||
|
||||
if zt == 1
|
||||
modelError = 0.0 # already correct, no weight update
|
||||
elseif vt1 > vt0 # progress increase
|
||||
modelError = 1.0 - vt1
|
||||
elseif vt1 == vt0 # no progress
|
||||
modelError = 0.11111111 # special signal
|
||||
elseif vt1 < vt0 # setback
|
||||
modelError = vt0 - vt1
|
||||
else
|
||||
error("undefined condition line $(@__LINE__)")
|
||||
end
|
||||
|
||||
return modelError
|
||||
end
|
||||
|
||||
# function arrayMax(x)
|
||||
# if sum(GeneralUtils.isNotEqual.(x, 0)) == 0 # guard against all-zeros array
|
||||
# return GeneralUtils.isNotEqual.(x, 0)
|
||||
|
||||
Reference in New Issue
Block a user