beta
This commit is contained in:
@@ -14,19 +14,6 @@
|
||||
# for i in condapackage CondaPkg.add(i) end
|
||||
|
||||
using Pkg; Pkg.activate("."); Pkg.resolve(), Pkg.instantiate()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# for debugging purpose #
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# https://discourse.julialang.org/t/debugging-extremely-slow/53801/3
|
||||
# using MethodAnalysis
|
||||
# visit(Base) do item
|
||||
# isa(item, Module) && push!(JuliaInterpreter.compiled_modules, item)
|
||||
# true
|
||||
# end
|
||||
|
||||
|
||||
using Revise
|
||||
using BenchmarkTools, Cthulhu
|
||||
using Flux, CUDA
|
||||
@@ -139,7 +126,7 @@ function generate_snn(filename::String, location::String)
|
||||
:type => "linearNeuron",
|
||||
:v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate)
|
||||
:tau_out => 100.0, # output time constant in millisecond.
|
||||
:synapticConnectionPercent => 20, # % coverage of total neurons in kfn
|
||||
:synapticConnectionPercent => 100, # % coverage of total neurons in kfn
|
||||
# Good starting value is 1/50th of tau_a
|
||||
# This is problem specific parameter.
|
||||
# It controls how leaky the neuron is.
|
||||
@@ -414,19 +401,17 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
bestAccuracy = 0.0
|
||||
finalAnswer = [0] |> device # store model prediction in (logit of choices, batch)
|
||||
stop = 0
|
||||
vt0 = 0.0 # store vt to compute learning progress
|
||||
for epoch = 1:1000
|
||||
stop == 3 ? break : false
|
||||
println("epoch $epoch")
|
||||
n = length(trainData)
|
||||
println("n $n")
|
||||
p = Progress(n, dt=1.0) # minimum update interval: 1 second
|
||||
for (imgBatch, labels) in trainData # imgBatch(28, 28, 4) i.e. (row, col, batch), labels(label, batch)
|
||||
for (imgBatch, labels) in trainData # imgBatch (28, 28, 4) i.e. (row, col, batch)
|
||||
for rep in 1:10
|
||||
stop == 3 ? break : false
|
||||
|
||||
# prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
|
||||
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.1), copies=18)
|
||||
# signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
|
||||
signal = spikeGenerator(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.1), copies=18)
|
||||
if length(size(signal)) == 3
|
||||
row, col, sequence = size(signal)
|
||||
batch = 1
|
||||
@@ -435,7 +420,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
end
|
||||
|
||||
# encode labels
|
||||
correctAnswer = onehotbatch(labels, labelDict) # (correctAnswer, batch)
|
||||
correctAnswer = onehotbatch(labels, labelDict) # (choices, batch)
|
||||
|
||||
# insert data into model sequencially
|
||||
for timestep in 1:(sequence + thinkingPeriod) # sMNIST has 784 timestep(pixel) + thinking period = 1000 timestep
|
||||
@@ -448,7 +433,6 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
if timestep == 1 # tell a model to start learning. 1-time only
|
||||
model.learningStage = [1]
|
||||
finalAnswer = [0] |> device
|
||||
vt0 = 0.0
|
||||
elseif timestep == (sequence+thinkingPeriod)
|
||||
model.learningStage = [3]
|
||||
else
|
||||
@@ -469,20 +453,9 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
# no error calculation
|
||||
elseif timestep == sequence # online learning, 1-by-1 timestep
|
||||
# no error calculation
|
||||
|
||||
#WORKING answer time windows, collect logit to get finalAnswer
|
||||
elseif timestep > sequence && timestep < sequence+thinkingPeriod
|
||||
logit_cpu = logit |> cpu
|
||||
logit_cpu = logit_cpu[:,1]
|
||||
elseif timestep > sequence && timestep < sequence+thinkingPeriod # collect answer
|
||||
finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit # (logit, batch)
|
||||
finalAnswer_cpu = finalAnswer |> cpu
|
||||
on_vt_cpu = model.on_vt |> cpu
|
||||
on_vt_cpu = on_vt_cpu[1,1,:,1]
|
||||
|
||||
modelError = loss(vt0, on_vt_cpu, logit_cpu, finalAnswer_cpu, labels[1])
|
||||
vt0 = on_vt_cpu # update vt0 for this timestep
|
||||
|
||||
error("DEBUG -> main $(Dates.now())")
|
||||
predict_cpu = logit |> cpu
|
||||
|
||||
modelError = (predict_cpu .- correctAnswer)
|
||||
modelError = reshape(modelError, (1,1,:, size(modelError, 2)))
|
||||
@@ -522,28 +495,14 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
# # error("DEBUG -> main $(Dates.now())")
|
||||
# end
|
||||
|
||||
elseif timestep == sequence+thinkingPeriod #TODO update code
|
||||
logit_cpu = logit |> cpu
|
||||
elseif timestep == sequence+thinkingPeriod
|
||||
finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit # (logit, batch)
|
||||
finalAnswer_cpu = finalAnswer |> cpu
|
||||
on_vt_cpu = model.on_vt |> cpu
|
||||
on_vt_cpu = on_vt_cpu[1,1,:,1]
|
||||
|
||||
# get vt of correct neuron, julia array is 1-based index
|
||||
labelPosition = labels[1] + 1
|
||||
on_vt_cpu = on_vt_cpu[labelPosition]
|
||||
predict_cpu = logit |> cpu
|
||||
|
||||
modelError = loss(vt0, on_vt_cpu, logit_cpu, finalAnswer_cpu)
|
||||
vt0 = on_vt_cpu
|
||||
|
||||
error("DEBUG -> main $(Dates.now())")
|
||||
|
||||
|
||||
|
||||
modelError = (logit_cpu .- correctAnswer)
|
||||
modelError = (predict_cpu .- correctAnswer)
|
||||
modelError = reshape(modelError, (1,1,:, size(modelError, 2)))
|
||||
modelError = sum(modelError, dims=3) |> device
|
||||
outputError = (logit_cpu .- correctAnswer) |> device
|
||||
outputError = (predict_cpu .- correctAnswer) |> device
|
||||
|
||||
lif_epsilonRec_cpu = model.lif_epsilonRec |> cpu
|
||||
on_zt_cpu = model.on_zt |> cpu
|
||||
@@ -796,33 +755,6 @@ function validate(model, dataset, labelDict)
|
||||
return percentCorrect::Float64
|
||||
end
|
||||
|
||||
function dualTrackSpikeGen(inputsignals, thresholds=[1.0]; noise=(false, 1, 0.5), copies=0)
|
||||
rowInputSignal = nothing
|
||||
colInputSignal = nothing
|
||||
|
||||
for slice in eachslice(inputsignals, dims=3)
|
||||
srow = nothing
|
||||
scol = nothing
|
||||
for row in eachrow(slice)
|
||||
srow = srow === nothing ? row : cat(srow, row, dims=1)
|
||||
end
|
||||
|
||||
for col in eachcol(slice)
|
||||
scol = scol === nothing ? col : cat(scol, col, dims=1)
|
||||
end
|
||||
|
||||
rowInputSignal = rowInputSignal === nothing ? srow : cat(rowInputSignal, srow, dims=3)
|
||||
colInputSignal = colInputSignal === nothing ? scol : cat(colInputSignal, scol, dims=3)
|
||||
end
|
||||
rowInputSignal = reshape(rowInputSignal, (size(rowInputSignal, 1), 1, size(inputsignals, 3)))
|
||||
colInputSignal = reshape(colInputSignal, (size(colInputSignal, 1), 1, size(inputsignals, 3)))
|
||||
rowInputSignal = spikeGenerator(rowInputSignal, thresholds, noise=noise, copies=8)
|
||||
colInputSignal = spikeGenerator(colInputSignal, thresholds, noise=noise, copies=8)
|
||||
|
||||
signal = cat(rowInputSignal, colInputSignal, dims=2)
|
||||
return signal
|
||||
end
|
||||
|
||||
""" inputsignals is normal column-major julia matrix in (row, col, batch) dimension
|
||||
- each threshold scan return 2 vectors. 1 for +, 1 for -
|
||||
- noise = (true/false, row, col, probability)
|
||||
@@ -874,33 +806,6 @@ function noiseGenerator(row, col, z; prob=0.5)
|
||||
return noise
|
||||
end
|
||||
|
||||
function loss(vt0::AbstractArray, vt1::AbstractArray, logit::AbstractArray,
|
||||
finalAnswer::AbstractArray, correctAnswer::Number)
|
||||
labelPosition = correctAnswer + 1 # julia array is 1-based index
|
||||
|
||||
# get vt of correct neuron
|
||||
vt1 = vt1[labelPosition]
|
||||
|
||||
# get zt of correct neuron
|
||||
zt = logit[labelPosition]
|
||||
|
||||
modelError = nothing
|
||||
|
||||
if zt == 1
|
||||
modelError = 0.0 # already correct, no weight update
|
||||
elseif vt1 > vt0 # progress increase
|
||||
modelError = 1.0 - vt1
|
||||
elseif vt1 == vt0 # no progress
|
||||
modelError = 0.11111111 # special signal
|
||||
elseif vt1 < vt0 # setback
|
||||
modelError = vt0 - vt1
|
||||
else
|
||||
error("undefined condition line $(@__LINE__)")
|
||||
end
|
||||
|
||||
return modelError
|
||||
end
|
||||
|
||||
# function arrayMax(x)
|
||||
# if sum(GeneralUtils.isNotEqual.(x, 0)) == 0 # guard against all-zeros array
|
||||
# return GeneralUtils.isNotEqual.(x, 0)
|
||||
|
||||
Reference in New Issue
Block a user