This commit is contained in:
ton
2023-09-19 20:46:40 +07:00
parent 05d8cabaf8
commit 821f12c86b
2 changed files with 32 additions and 937 deletions

View File

@@ -27,8 +27,10 @@ using Pkg; Pkg.activate("."); Pkg.resolve(), Pkg.instantiate()
# end
using Revise
using BenchmarkTools, Cthulhu
using BenchmarkTools, Cthulhu, REPL.TerminalMenus
using Flux, CUDA
using BSON, JSON3
using MLDatasets: MNIST
@@ -66,6 +68,9 @@ if device == gpu CUDA.device!(0) end #CHANGE
-
"""
# ----------------------------- REPL menu options ---------------------------- #
options = ["yes", "no"]
menu = RadioMenu(options)
# communication config --------------------------------------------------------------------------100
@@ -416,17 +421,22 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
stop = 0
vt0 = 0.0 # store vt to compute learning progress
for epoch = 1:1000
stop == 3 ? break : false
stop == 1 ? break : false
println("epoch $epoch")
n = length(trainData)
println("n $n")
p = Progress(n, dt=1.0) # minimum update interval: 1 second
for (imgBatch, labels) in trainData # imgBatch(28, 28, 4) i.e. (row, col, batch), labels(label, batch)
for rep in 1:10
stop == 3 ? break : false
stop == 1 ? break : false
consecutiveCorrect = 0
rep = 0
# for rep in 1:20
while consecutiveCorrect < 10
rep += 1
stop == 1 ? break : false
# prepare image into input signal (10, 2, 784, 4) i.e. (row, col, timestep, batch)
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.5), copies=18)
signal = dualTrackSpikeGen(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 1.0), copies=18)
if length(size(signal)) == 3
row, col, sequence = size(signal)
batch = 1
@@ -472,7 +482,7 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
elseif timestep == sequence # online learning, 1-by-1 timestep
# no error calculation
#WORKING answer time windows, collect logit to get finalAnswer
# answer time windows, collect logit to get finalAnswer
elseif timestep > sequence && timestep < sequence+thinkingPeriod
logit_cpu = logit |> cpu
logit_cpu = logit_cpu[:,1]
@@ -565,17 +575,21 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
# println("label $(labels[1]) finalAnswer $finalAnswer_cpu")
max = isequal.(finalAnswer_cpu[:,1], maximum(finalAnswer_cpu[:,1]))
if sum(finalAnswer_cpu) == 0
IronpenGPU.learn!(model, progress, device)
consecutiveCorrect = 0
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer __ LEARNING")
IronpenGPU.learn!(model, progress, device)
elseif sum(max) == 1 && findall(max)[1] -1 == labels[1]
IronpenGPU.learn!(model, progress, device)
consecutiveCorrect += 1
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu CORRECT")
IronpenGPU.learn!(model, progress, device)
elseif sum(max) == 1 && findall(max)[1] -1 != labels[1]
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
IronpenGPU.learn!(model, progress, device)
consecutiveCorrect = 0
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
else
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
IronpenGPU.learn!(model, progress, device)
consecutiveCorrect = 0
println("modelname $modelname epoch $epoch rep $rep label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
end
# error("DEBUG -> main $(Dates.now())")
@@ -683,6 +697,14 @@ function train_snn(model, trainData, validateData, labelDict::Vector)
# break
# end
end
#WORKING add menu
# choice = request("continue?", menu)
# if choice == "yes"
# continue
# else
# stop = 1
# end
next!(p)
end