dev
This commit is contained in:
@@ -27,8 +27,10 @@ using Pkg; Pkg.activate("."); Pkg.resolve(), Pkg.instantiate()
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
using Revise
|
||||
using BenchmarkTools, Cthulhu
|
||||
using BenchmarkTools, Cthulhu, REPL.TerminalMenus
|
||||
using Flux, CUDA
|
||||
using BSON, JSON3
|
||||
using MLDatasets: MNIST
|
||||
@@ -66,6 +68,9 @@ if device == gpu CUDA.device!(0) end #CHANGE
|
||||
-
|
||||
"""
|
||||
|
||||
# ----------------------------- REPL menu options ---------------------------- #
|
||||
options = ["yes", "no"]
|
||||
menu = RadioMenu(options)
|
||||
|
||||
# communication config --------------------------------------------------------------------------100
|
||||
|
||||
@@ -416,17 +421,22 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
stop = 0
|
||||
vt0 = 0.0 # store vt to compute learning progress
|
||||
for epoch = 1:1000
|
||||
stop == 3 ? break : false
|
||||
stop == 1 ? break : false
|
||||
println("epoch $epoch")
|
||||
n = length(trainData)
|
||||
println("n $n")
|
||||
p = Progress(n, dt=1.0) # minimum update interval: 1 second
|
||||
for (imgBatch, labels) in trainData # imgBatch(28, 28, 4) i.e. (row, col, batch), labels(label, batch)
|
||||
for rep in 1:10
|
||||
stop == 3 ? break : false
|
||||
stop == 1 ? break : false
|
||||
consecutiveCorrect = 0
|
||||
rep = 0
|
||||
# for rep in 1:20
|
||||
while consecutiveCorrect < 10
|
||||
rep += 1
|
||||
stop == 1 ? break : false
|
||||
|
||||
# prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
|
||||
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.5), copies=18)
|
||||
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 1.0), copies=18)
|
||||
if length(size(signal)) == 3
|
||||
row, col, sequence = size(signal)
|
||||
batch = 1
|
||||
@@ -472,7 +482,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
elseif timestep == sequence # online learning, 1-by-1 timestep
|
||||
# no error calculation
|
||||
|
||||
#WORKING answer time windows, collect logit to get finalAnswer
|
||||
# answer time windows, collect logit to get finalAnswer
|
||||
elseif timestep > sequence && timestep < sequence+thinkingPeriod
|
||||
logit_cpu = logit |> cpu
|
||||
logit_cpu = logit_cpu[:,1]
|
||||
@@ -565,17 +575,21 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
# println("label $(labels[1]) finalAnswer $finalAnswer_cpu")
|
||||
max = isequal.(finalAnswer_cpu[:,1], maximum(finalAnswer_cpu[:,1]))
|
||||
if sum(finalAnswer_cpu) == 0
|
||||
IronpenGPU.learn!(model, progress, device)
|
||||
consecutiveCorrect = 0
|
||||
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer __ LEARNING")
|
||||
IronpenGPU.learn!(model, progress, device)
|
||||
elseif sum(max) == 1 && findall(max)[1] -1 == labels[1]
|
||||
IronpenGPU.learn!(model, progress, device)
|
||||
consecutiveCorrect += 1
|
||||
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu CORRECT")
|
||||
IronpenGPU.learn!(model, progress, device)
|
||||
elseif sum(max) == 1 && findall(max)[1] -1 != labels[1]
|
||||
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
|
||||
IronpenGPU.learn!(model, progress, device)
|
||||
consecutiveCorrect = 0
|
||||
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
|
||||
else
|
||||
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
|
||||
IronpenGPU.learn!(model, progress, device)
|
||||
consecutiveCorrect = 0
|
||||
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
|
||||
end
|
||||
|
||||
# error("DEBUG -> main $(Dates.now())")
|
||||
@@ -683,6 +697,14 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
|
||||
# break
|
||||
# end
|
||||
end
|
||||
#WORKING add menu
|
||||
# choice = request("continue?", menu)
|
||||
# if choice == "yes"
|
||||
# continue
|
||||
# else
|
||||
# stop = 1
|
||||
# end
|
||||
|
||||
|
||||
next!(p)
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user