847 lines
37 KiB
Julia
847 lines
37 KiB
Julia
# ---------------------------------------------------------------------------- #
|
||
#                if one needs to reinstall all Python packages                 #
|
||
# ---------------------------------------------------------------------------- #
|
||
# 1. delete .CondaPkg folder in working folder
|
||
# 2. delete CondaPkg.toml file in working folder
|
||
# using Pkg; Pkg.activate(".");
|
||
# pythonPkg = ["CondaPkg", "PythonCall"]
|
||
# for i in pythonPkg try Pkg.rm(i) catch end end
|
||
# for i in pythonPkg Pkg.add(i) end
|
||
# using CondaPkg, PythonCall
|
||
# channels = ["anaconda", "conda-forge", "pytorch"]
|
||
# for i in channels CondaPkg.add_channel(i) end
|
||
# condapackage = ["numpy", "pytorch", "snntorch"]
|
||
# for i in condapackage CondaPkg.add(i) end
|
||
|
||
using Pkg; Pkg.activate("."); Pkg.resolve(), Pkg.instantiate()
|
||
using Revise
|
||
using BenchmarkTools, Cthulhu
|
||
using Flux, CUDA
|
||
using BSON, JSON3
|
||
using MLDatasets: MNIST
|
||
using MLUtils, ProgressMeter, Dates, Random,
|
||
Serialization, OneHotArrays , GLMakie
|
||
|
||
using CondaPkg, PythonCall
|
||
np = pyimport("numpy")
|
||
torch = pyimport("torch")
|
||
spikegen = pyimport("snntorch.spikegen") # https://github.com/jeshraghian/snntorch
|
||
|
||
using IronpenGPU
|
||
using GeneralUtils
|
||
|
||
# Path separator of the host operating system and the directory the script was started from.
sep = if Sys.iswindows()
    "\\"
else
    "/"
end
rootDir = pwd()

# Select the compute device. Flux provides the "cpu" and "gpu" keywords.
# device = Flux.CUDA.functional() ? gpu : cpu
device = gpu
# When running on the GPU, make CUDA device 0 the active device.
device == gpu && CUDA.device!(0) #CHANGE
|
||
# CUDA.allowscalar(false) # turn off scalar indexing in CPU to make it easier when moving to GPU
|
||
#------------------------------------------------------------------------------------------------100
|
||
|
||
|
||
|
||
"""
|
||
Todo:
|
||
- []
|
||
|
||
Change from version:
|
||
-
|
||
|
||
All features
|
||
-
|
||
"""
|
||
|
||
|
||
# communication config --------------------------------------------------------------------------100
|
||
|
||
# Address of the database host (presumably used by the communication layer; it is not
# referenced anywhere else in this part of the file).
database_ip = "localhost"
# database_ip = "192.168.0.8"

#------------------------------------------------------------------------------------------------100
# Name used in log messages and for the model file name.
modelname = "runOn_gpu_0" #CHANGE
# Number of images per batch; also the 3rd dimension of the signal input ports.
imageBatch = 1
|
||
|
||
|
||
"""
    generate_snn(filename::String, location::String)

Assemble the parameter dictionary of the spiking neural network and build the model with
`IronpenGPU.kfn_1` on the global `device`.

# Arguments
- `filename`, `location`: intended target of the serialized model. Both are currently
  unused because the `serialize` call is commented out.

# Returns
The model object returned by `IronpenGPU.kfn_1`.
"""
function generate_snn(filename::String, location::String)
    # (row, col, batch). The 2nd dim needs to match
    # input signal + copied input signal + noise. The 3rd dim is the input batch size.
    signalInput_portnumbers = (10, 20, imageBatch)
    portRows = signalInput_portnumbers[1]
    noise_portnumbers = (portRows, 1)
    output_portnumbers = (10, 1)

    # 5000 neurons are maximum for 64GB memory i.e. 300 LIF : 200 ALIF
    lif_neuron_number = (portRows, 3) # CHANGE
    alif_neuron_number = (portRows, 2) # CHANGE from Allen Institute, ALIF is 20-40% of LIF

    # totalNeurons = computeNeuronNumber + noise_portnumbers + signalInput_portnumbers
    # totalInputPort = noise_portnumbers + signalInput_portnumbers

    # ------------------------------ neuron config ------------------------------ #
    passthrough_neuron_params = Dict(:type => "passthroughNeuron")

    # Author's tuning notes for the LIF/ALIF parameters below:
    #   Good starting value is 1/10th of tau_a.
    #   This is problem specific parameter. It controls how leaky the neuron is.
    #   Too high (less leaky) makes learning algo harder to move model into direction that
    #   reduce error, resulting in model's error to explode exponentially, likely because
    #   learning algo will try to exert more force (larger w_out_change) to move neuron
    #   into direction that reduce error. For example, model error from 7 to 2e6.
    lif_neuron_params = Dict{Symbol, Any}(
        :type => "lifNeuron",
        :v_t_default => 0.0,
        :v_th => 1.0, # firing threshold (treated as maximum bound if auto generated)
        :tau_m => 50.0, # membrane time constant in millisecond
        :eta => 1e-6,
        :synapticConnectionPercent => 20, # % coverage of total neurons in kfn
    )

    # The ALIF neuron shares every LIF setting and adds the adaptation time constant.
    #   tau_a: adaptation time constant in millisecond. It defines neuron memory length.
    #   This is problem specific parameter. Good starting value is 0.5 to 2 times of info
    #   STORE-RECALL length i.e. total time SNN takes to perform a task, for example,
    #   equals to episode length. From "Spike frequency adaptation supports network
    #   computations on temporally dispersed information".
    alif_neuron_params = merge(
        lif_neuron_params,
        Dict{Symbol, Any}(:type => "alifNeuron", :tau_a => 800.0),
    )

    # Author's tuning notes for the output neurons below:
    #   Good starting value is 1/50th of tau_a. This is problem specific parameter.
    #   It controls how leaky the neuron is. Too high (less leaky) makes learning algo
    #   harder to move model into direction that reduce error, resulting in model's error
    #   to explode exponentially. For example, model error from 7 to 2e6.
    #   One can imagine training an output neuron is like a Tetris game.
    linear_neuron_params = Dict{Symbol, Any}(
        :type => "linearNeuron",
        :v_th => 1.0, # firing threshold (treated as maximum bound if auto generated)
        :tau_out => 100.0, # output time constant in millisecond
        :synapticConnectionPercent => 100, # % coverage of total neurons in kfn
    )

    # Alternative output neuron configuration; defined but not wired into the model below.
    integrate_neuron_params = Dict{Symbol, Any}(
        :type => "integrateNeuron",
        :synapticConnectionPercent => 100, # % coverage of total neurons in kfn
        :eta => 1e-6,
        :tau_out => 100.0,
    )

    # -------------------------------- kfn config -------------------------------- #
    inputPort = Dict(
        :noise => Dict(
            :numbers => noise_portnumbers,
            :params => passthrough_neuron_params,
        ),
        :signal => Dict(
            :numbers => signalInput_portnumbers, # in case of GloVe word encoding, it is 300
            :params => passthrough_neuron_params,
        ),
    )
    outputPort = Dict(
        :numbers => output_portnumbers, # output neuron, this is also the output length
        :params => linear_neuron_params,
    )
    computeNeuron = Dict(
        :lif => Dict(
            :numbers => lif_neuron_number, # number in (row, col) tuple format
            :params => lif_neuron_params,
        ),
        :alif => Dict(
            :numbers => alif_neuron_number, # number in (row, col) tuple format
            :params => alif_neuron_params,
        ),
    )
    I_kfnparams = Dict{Symbol, Any}(
        :knowledgeFnName => "I",
        :neuronFiringRateTarget => 20.0, # Hz
        :inputPort => inputPort,
        :outputPort => outputPort,
        :computeNeuron => computeNeuron,
    )

    #------------------------------------------------------------------------------------------------100

    model = IronpenGPU.kfn_1(I_kfnparams, device=device)

    # serialize(location * sep * filename, model)
    println("SNN generated")

    return model
end
|
||
|
||
"""
    data_loader()

Create the training and validation data loaders together with the label dictionary.

Both loaders currently wrap the first three samples of the MNIST training split (a small
test problem); the batch size is taken from the global `imageBatch`.

# Returns
A tuple `(trainData, validateData, labelDict)` where `labelDict` is the vector `0:9`.
"""
function data_loader()
    # test problem: only 3 samples are used (the training split has 60000 in total)
    trainDataset = MNIST(:train)[1:3]
    # validateDataset = MNIST(:test)
    validateDataset = MNIST(:train)[1:3]
    labelDict = collect(0:9)

    # Options shared by both loaders.
    loaderOptions = (
        batchsize = imageBatch,
        collate = true,
        shuffle = true,
        buffer = true,
        partial = false, # better for gpu memory if batchsize is fixed
        # parallel = true, #BUG ?? causing dataloader into forever loop
    )

    trainData = MLUtils.DataLoader(trainDataset; loaderOptions...) # fullTrainDataset or trainDataset
    validateData = MLUtils.DataLoader(validateDataset; loaderOptions...)

    # dummy data used to debug
    # trainData = [(rand(10, 10), [5]), (rand(10, 10), [2])]
    # trainData = [(rand(10, 10), [5]),]

    return trainData, validateData, labelDict
end
|
||
|
||
"""
    train_snn(model, trainData, validateData, labelDict::Vector)

Train `model` online on `trainData`.

For up to 1000 epochs every batch is presented 10 times. Each presentation spike-encodes
the images with `spikeGenerator`, feeds the signal to the model one timestep at a time and
then feeds `thinkingPeriod` all-zero timesteps. During those extra timesteps the logits
are accumulated into the final answer and `IronpenGPU.compute_paramsChange!` is called with
the current output error. On the last timestep the accumulated answer is compared with the
label and `IronpenGPU.learn!` is called unless the answer is correct.

From epoch 201 onwards the model is checked with `validate` after every epoch.

The answer check only looks at the first sample of the batch (`labels[1]`), so it assumes
a batch size of 1.

# Arguments
- `model`: model returned by `generate_snn`.
- `trainData`, `validateData`: iterables of `(imgBatch, labels)` pairs.
- `labelDict`: vector of all possible labels, used for one-hot encoding.

# Returns
`nothing`.
"""
function train_snn(model, trainData, validateData, labelDict::Vector)
    # random seed
    # rng = MersenneTwister(1234)

    # The monitoring figure is built but never displayed; call `display` on the returned
    # figure and push new values into the returned observables to use it while debugging.
    _build_training_figure()

    # model learning
    thinkingPeriod = 16 # 1000-784 = 216
    bestAccuracy = 0.0
    finalAnswer = [0] |> device # store model prediction in (logit of choices, batch)
    stop = 0
    for epoch in 1:1000
        stop == 3 && break
        println("epoch $epoch")
        n = length(trainData)
        println("n $n")
        p = Progress(n, dt=1.0) # minimum update interval: 1 second
        for (imgBatch, labels) in trainData # imgBatch is (row, col, batch)
            for rep in 1:10
                stop == 3 && break

                # signal is (row, col, timestep, batch); the batch dim is absent for 1 image
                signal = spikeGenerator(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.1), copies=18)
                if ndims(signal) == 3
                    row, col, sequence = size(signal)
                    batch = 1
                else
                    row, col, sequence, batch = size(signal)
                end

                # encode labels
                correctAnswer = onehotbatch(labels, labelDict) # (choices, batch)

                # insert data into model sequentially:
                # input timesteps first, then the "thinking" period
                lastStep = sequence + thinkingPeriod
                for timestep in 1:lastStep
                    current_pixel = if timestep <= sequence
                        view(signal, :, :, timestep, :) |> device
                    else
                        zeros(row, col, batch) |> device # dummy input in "thinking" period
                    end

                    if timestep == 1 # tell a model to start learning. 1-time only
                        model.learningStage = [1]
                        finalAnswer = [0] |> device
                    elseif timestep == lastStep
                        model.learningStage = [3]
                    end

                    # predict
                    logit, _ = model(current_pixel)

                    # no error calculation while the input signal is still being fed
                    timestep > sequence || continue

                    # collect answer, (logit, batch)
                    finalAnswer = length(finalAnswer) == 1 ? logit : finalAnswer .+ logit

                    if timestep < lastStep
                        _accumulate_paramsChange!(model, logit, correctAnswer)
                        continue
                    end

                    # ----------------------------- last timestep ----------------------------- #
                    # on_zt is sampled before the parameter change is computed, the other
                    # debug values after it.
                    on_zt_cpu = model.on_zt |> cpu
                    _accumulate_paramsChange!(model, logit, correctAnswer)
                    _print_training_debug(model, on_zt_cpu)

                    # commit learned weight only if the model answers incorrectly
                    finalAnswer_cpu = finalAnswer |> cpu
                    firstSample = finalAnswer_cpu[:, 1]
                    isMax = isequal.(firstSample, maximum(firstSample))
                    if sum(finalAnswer_cpu) == 0
                        println("modelname $modelname epoch $epoch label $(labels[1]) finalAnswer ZERO answer LEARNING")
                        IronpenGPU.learn!(model, device)
                    elseif sum(isMax) == 1
                        # exactly one winning output neuron; labels start at 0
                        predictedLabel = findfirst(isMax) - 1
                        if predictedLabel == labels[1]
                            println("modelname $modelname epoch $epoch label $(labels[1]) finalAnswer $predictedLabel CORRECT")
                        else
                            println("modelname $modelname epoch $epoch label $(labels[1]) finalAnswer $predictedLabel LEARNING")
                            IronpenGPU.learn!(model, device)
                        end
                    else
                        # several output neurons tie for the maximum
                        println("modelname $modelname epoch $epoch label $(labels[1]) finalAnswer $finalAnswer_cpu LEARNING")
                        IronpenGPU.learn!(model, device)
                    end
                end
            end

            next!(p)
        end

        if epoch > 200
            # check accuracy
            println("validating model")
            percentCorrect = validate(model, validateData, labelDict)
            bestAccuracy = percentCorrect > bestAccuracy ? percentCorrect : bestAccuracy
            println("$modelname model accuracy is $percentCorrect %, best accuracy is $bestAccuracy")
        end
    end

    return nothing
end

# Build the monitoring figure: one axis with a single trace for the fired neurons and five
# axes with ten traces each (labelled "0" to "9"). Returns the figure, the observable of
# the first axis and a vector (one entry per remaining axis) of ten observables.
function _build_training_figure()
    fig = Figure()

    firedObs = Observable(zeros(1))
    firedAxis = GLMakie.Axis(fig[1, 1], # define position of this subfigure inside a figure
        title = "RSNN firedNeurons_t1",
        xlabel = "time",
        ylabel = "data",
    )
    lines!(firedAxis, firedObs, label = "firedNeurons_t1")

    traceTitles = [
        "output neurons logit",
        "last RSNN wRec",
        "RSNN v_t1",
        "output neuron epsilonRec",
        "output neuron wRecChange",
    ]
    traceObs = [[Observable(zeros(2)) for _ in 1:10] for _ in traceTitles]
    for (k, title) in enumerate(traceTitles)
        axis = GLMakie.Axis(fig[k + 1, 1], title = title, xlabel = "time", ylabel = "data")
        for (i, obs) in enumerate(traceObs[k])
            lines!(axis, obs, label = string(i - 1), color = i, colormap = :tab10, colorrange = (1, 10))
        end
    end

    return fig, firedObs, traceObs
end

# Compute the output error of `logit` against the one-hot `correctAnswer` and hand it to
# `IronpenGPU.compute_paramsChange!`, both per output neuron and summed over the neurons.
function _accumulate_paramsChange!(model, logit, correctAnswer)
    outputError_cpu = (logit |> cpu) .- correctAnswer # (choices, batch)
    modelError = reshape(outputError_cpu, (1, 1, :, size(outputError_cpu, 2)))
    modelError = sum(modelError, dims=3) |> device
    outputError = outputError_cpu |> device
    IronpenGPU.compute_paramsChange!(model, modelError, outputError)
    return nothing
end

# Print one line of debug values read back from the model. `on_zt_cpu` is the host copy of
# `model.on_zt` taken before the parameter change was computed.
function _print_training_debug(model, on_zt_cpu)
    println("")
    lif_recSignal_cpu = sum((model.lif_recSignal |> cpu)[:, :, 5, 1])
    lif_vt_cpu = (model.lif_vt |> cpu)[1, 1, 5, 1]
    lif_zt_cpu = (model.lif_zt |> cpu)[1, 1, 5, 1]
    lif_epsilonRec_cpu = sum((model.lif_epsilonRec |> cpu)[:, :, 5, 1])
    lif_wRecChange_cpu = sum((model.lif_wRecChange |> cpu)[:, :, 5, 1])
    on_vt_cpu = (model.on_vt |> cpu)[1, 1, :, 1]
    on_zt_cpu = on_zt_cpu[1, 1, :, 1]
    on_wOutChange_cpu = sum(model.on_wOutChange |> cpu, dims=(1, 2))
    println("lif recSignal $lif_recSignal_cpu lif vt $lif_vt_cpu lif zt $lif_zt_cpu lif_epsilonRec_cpu $lif_epsilonRec_cpu lif_wRecChange_cpu $lif_wRecChange_cpu on_vt $on_vt_cpu on_zt $on_zt_cpu on_wOutChange_cpu $on_wOutChange_cpu")
    return nothing
end
|
||
|
||
"""
    validate(model, dataset, labelDict)

Measure the classification accuracy of `model` on `dataset`.

Every batch is spike-encoded with `spikeGenerator`, fed to the model one timestep at a
time and followed by `thinkingPeriod` all-zero timesteps. The logits produced after the
input sequence are summed, and `GeneralUtils.vectorMax` is applied per sample to pick the
predicted label. A sample with more than one winning label counts as wrong. A batch in
which any sample has no winning label at all is left out of the statistics entirely.

# Arguments
- `model`: callable model returning `(logit, _)` for one timestep of input.
- `dataset`: iterable of `(imgBatch, labels)` pairs.
- `labelDict`: currently unused; kept so existing callers keep working.

# Returns
Percentage of correctly answered samples as `Float64`; `0.0` when no batch was counted.
"""
function validate(model, dataset, labelDict)
    totalAnswerCorrectly = 0 # score
    totalSignal = 0
    thinkingPeriod = 16 # 1000-784 = 216
    predict = [0] |> device

    n = length(dataset)
    println("n $n")
    p = Progress(n, dt=1.0) # minimum update interval: 1 second
    for (imgBatch, labels) in dataset
        signal = spikeGenerator(imgBatch, [0.05, 0.1, 0.2, 0.3, 0.5], noise=(true, 1, 0.5), copies=18)
        if ndims(signal) == 3
            row, col, sequence = size(signal)
            batch = 1
        else
            row, col, sequence, batch = size(signal)
        end

        # insert data into model sequentially:
        # input timesteps first, then the "thinking" period
        for timestep in 1:(sequence + thinkingPeriod)
            current_pixel = if timestep <= sequence
                view(signal, :, :, timestep, :) |> device
            else
                zeros(row, col, batch) |> device # dummy input in "thinking" period
            end

            if timestep == 1 # start every batch with an empty answer
                predict = [0] |> device
            end

            # predict
            logit, _ = model(current_pixel)

            # collect answer only after the input sequence has ended, (logit, batch)
            if timestep > sequence
                predict = length(predict) == 1 ? logit : predict .+ logit
            end
        end

        predict_cpu = predict |> cpu
        _predict_label = mapslices(GeneralUtils.vectorMax, predict_cpu; dims=1)
        s = sum(_predict_label, dims=1)
        if 0 ∉ s
            predict_label = Int[]
            for i in eachcol(_predict_label)
                _label = findall(i) .- 1
                # predict more than 1 label: add the non-count label -1
                push!(predict_label, length(_label) == 1 ? _label[1] : -1)
            end
            totalAnswerCorrectly += count(x == y for (x, y) in zip(predict_label, labels))
            totalSignal += batch
        end

        next!(p)
    end

    # Avoid 0/0 = NaN when every batch was skipped (or the dataset is empty).
    totalSignal == 0 && return 0.0

    percentCorrect = totalAnswerCorrectly * 100.0 / totalSignal

    return percentCorrect::Float64
end
|
||
|
||
"""
    spikeGenerator(inputsignals, thresholds=[1.0]; noise=(false, 1, 0.5), copies=0)

Encode `inputsignals` into Boolean spike trains.

`inputsignals` is a normal column-major Julia array whose last dimension is the batch,
e.g. `(row, col, batch)`. Every sample is flattened into one sequence and scanned with
`spikegen.delta` once per threshold. Each threshold scan yields 2 rows: one marking the
`+1` spikes and one marking the `-1` spikes.

# Keywords
- `noise = (enabled, columns, probability)`: when `enabled` is `true`, `columns` columns
  produced by `noiseGenerator` with `prob = probability` are appended after the signal
  columns.
- `copies`: number of additional copies of the signal column (only applied when `> 0`).

# Returns
- a `(row, signal:noise, timestep)` array when `inputsignals` holds a single sample,
- a `(row, signal:noise, timestep, batch)` array when it holds several samples,
- an empty `Vector{Any}` when it holds no sample.

# Throws
- `ArgumentError` if `thresholds` is empty while there is at least one sample.
"""
function spikeGenerator(inputsignals, thresholds=[1.0]; noise=(false, 1, 0.5), copies=0)
    batchDim = ndims(inputsignals)
    encoded = [
        _encodeSample(slice, thresholds, noise, copies) for
        slice in eachslice(inputsignals, dims=batchDim)
    ]

    isempty(encoded) && return []
    # A single sample is returned without the trailing batch dimension.
    length(encoded) == 1 && return encoded[1]
    # concatenate into (row, signal:noise, timestep, batch)
    return cat(encoded...; dims=4)
end

# Encode one sample into a (row, signal:noise, timestep) Boolean array.
function _encodeSample(slice, thresholds, noise, copies)
    isempty(thresholds) && throw(ArgumentError("thresholds must not be empty"))

    signal_jl = reshape(slice, (:, 1)) # python array is row-major
    signal_pytensor = torch.from_numpy(np.asarray(signal_jl))

    # One (2, timestep) block per threshold: first row +1 spikes, second row -1 spikes.
    scans = map(thresholds) do threshold
        spike_py = spikegen.delta(signal_pytensor, threshold=threshold, off_spike=true)
        _spike_jl = pyconvert(Array, spike_py.data.numpy())
        spike_jl = reshape(_spike_jl, (1, :)) # reshape back to julia's column-major
        [isequal.(spike_jl, 1); isequal.(spike_jl, -1)]
    end
    stacked = vcat(scans...)
    arr = reshape(stacked, (size(stacked, 1), 1, size(stacked, 2))) # (row, 1, timestep)

    # multiply col: the original column plus `copies` duplicates
    if copies > 0
        arr = repeat(arr, 1, copies + 1, 1)
    end

    if noise[1] == true
        n = noiseGenerator(size(arr, 1), noise[2], size(arr, 3), prob=noise[3])
        arr = cat(arr, n, dims=2) # concatenate into (row, signal:noise, timestep)
    end

    return arr
end
|
||
|
||
"""
    noiseGenerator(row, col, z; prob=0.5)

Create a `(row, col, z)` Boolean array of random noise spikes.

Random values from `torch.rand` are scaled by `prob` and passed to `spikegen.rate_conv`;
the result marks every entry of that output that equals 1.
"""
function noiseGenerator(row, col, z; prob=0.5)
    spike_rand = spikegen.rate_conv(torch.rand(row, col, z) * prob)
    spikes = pyconvert(Array, spike_rand.data.numpy())
    return isequal.(spikes, 1)
end
|
||
|
||
# function arrayMax(x)
|
||
# if sum(GeneralUtils.isNotEqual.(x, 0)) == 0 # guard against all-zeros array
|
||
# return GeneralUtils.isNotEqual.(x, 0)
|
||
# else
|
||
# return isequal.(x, maximum(x))
|
||
# end
|
||
# end
|
||
# arraySliceMax(x) = mapslices(arrayMax, x; dims=1)
|
||
|
||
"""
    main()

Script entry point: build the network, create the data loaders, run the training and
print the start and finish timestamps.
"""
function main()
    scriptDir = string(@__DIR__)
    modelFile = "$modelname.jl163"

    training_start_time = Dates.now()
    println("$modelname program started $training_start_time")

    snn = generate_snn(modelFile, scriptDir)
    trainLoader, validateLoader, labelDict = data_loader()
    train_snn(snn, trainLoader, validateLoader, labelDict)

    finish_training_time = Dates.now()
    println("training done, $training_start_time ==> $finish_training_time ")
    println(" ///////////////////////////////////////////////////////////////////////")
end

# Only run main() if julia isn't started interactively.
# https://discourse.julialang.org/t/scripting-like-a-julian/50707
if !isinteractive()
    main()
end
|
||
#------------------------------------------------------------------------------------------------100
|
||
|
||
|
||
|
||
|
||
|
||
|