add 0.0.6
This commit is contained in:
127
example_main.jl
127
example_main.jl
@@ -7,17 +7,17 @@ using MLUtils, Images, ProgressMeter, Dates, DataFrames, Random, Statistics, Lin
|
||||
BenchmarkTools, Serialization, OneHotArrays , GLMakie # ClickHouse
|
||||
|
||||
|
||||
# if one need to reinstall all python packages
|
||||
# try Pkg.rm("PythonCall") catch end # should be removed before using CondaPkg to install packages
|
||||
# # if one need to reinstall all python packages
|
||||
# # try Pkg.rm("PythonCall") catch end # should be removed before using CondaPkg to install packages
|
||||
# condapackage = ["numpy", "pytorch", "snntorch"]
|
||||
# using CondaPkg # in CondaPkg.toml file, channels = ["anaconda", "conda-forge", "pytorch"]
|
||||
# CondaPkg.add_channel("pytorch")
|
||||
# for i in condapackage
|
||||
# try CondaPkg.rm(i) catch end
|
||||
# end
|
||||
# for i in condapackage
|
||||
# CondaPkg.add(i)
|
||||
# end
|
||||
# Pkg.add("PythonCall");
|
||||
|
||||
using PythonCall;
|
||||
np = pyimport("numpy")
|
||||
@@ -31,10 +31,9 @@ sep = Sys.iswindows() ? "\\" : "/"
|
||||
rootDir = pwd()
|
||||
|
||||
# select compute device
|
||||
# device = Flux.CUDA.functional() ? gpu : cpu
|
||||
# if device == gpu
|
||||
# CUDA.device!(3)
|
||||
# end
|
||||
# device = Flux.CUDA.functional() ? gpu : cpu # Flux provide "cpu" and "gpu" keywork
|
||||
# if device == gpu CUDA.device!(3) end
|
||||
# CUDA.allowscalar(false) # turn off scalar indexing
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
|
||||
@@ -59,12 +58,12 @@ database_ip = "localhost"
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
function generate_snn(filename::String, location::String)
|
||||
expect_compute_neuron_numbers = 1024 #FIXME change to 512
|
||||
expect_compute_neuron_numbers = 1024
|
||||
signalInput_portnumbers = 50
|
||||
noise_portnumbers = signalInput_portnumbers
|
||||
noise_portnumbers = 1 #signalInput_portnumbers
|
||||
output_portnumbers = 10
|
||||
|
||||
lif_neuron_number = Int(floor(expect_compute_neuron_numbers * 0.4))
|
||||
lif_neuron_number = Int(floor(expect_compute_neuron_numbers * 0.6))
|
||||
alif_neuron_number = expect_compute_neuron_numbers - lif_neuron_number # from Allen Institute, ALIF is 20-40% of LIF
|
||||
computeNeuronNumber = lif_neuron_number + alif_neuron_number
|
||||
|
||||
@@ -80,8 +79,8 @@ function generate_snn(filename::String, location::String)
|
||||
:type => "lifNeuron",
|
||||
:v_t_default => 0.0,
|
||||
:v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate)
|
||||
:tau_m => 200.0, # membrane time constant in millisecond.
|
||||
:eta => 1e-2,
|
||||
:tau_m => 20.0, # membrane time constant in millisecond.
|
||||
:eta => 1e-6,
|
||||
# Good starting value is 1/10th of tau_a
|
||||
# This is problem specific parameter. It controls how leaky the neuron is.
|
||||
# Too high(less leaky) makes learning algo harder to move model into direction that reduce error
|
||||
@@ -89,7 +88,7 @@ function generate_snn(filename::String, location::String)
|
||||
# exert more force (larger w_out_change) to move neuron into direction that reduce error
|
||||
# For example, model error from 7 to 2e6.
|
||||
|
||||
:synapticConnectionPercent => 50, # % coverage of total neurons in kfn
|
||||
:synapticConnectionPercent => 20, # % coverage of total neurons in kfn
|
||||
:w_rec_generation_pattern => "random",
|
||||
)
|
||||
|
||||
@@ -97,8 +96,8 @@ function generate_snn(filename::String, location::String)
|
||||
:type => "alifNeuron",
|
||||
:v_t_default => 0.0,
|
||||
:v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate)
|
||||
:tau_m => 200.0, # membrane time constant in millisecond.
|
||||
:eta => 1e-2,
|
||||
:tau_m => 20.0, # membrane time constant in millisecond.
|
||||
:eta => 1e-6,
|
||||
# Good starting value is 1/10th of tau_a
|
||||
# This is problem specific parameter. It controls how leaky the neuron is.
|
||||
# Too high(less leaky) makes learning algo harder to move model into direction that reduce error
|
||||
@@ -106,35 +105,35 @@ function generate_snn(filename::String, location::String)
|
||||
# exert more force (larger w_out_change) to move neuron into direction that reduce error
|
||||
# For example, model error from 7 to 2e6.
|
||||
|
||||
:tau_a => 500.0, # adaptation time constant in millisecond. it defines neuron memory length.
|
||||
:tau_a => 800.0, # adaptation time constant in millisecond. it defines neuron memory length.
|
||||
# This is problem specific parameter
|
||||
# Good starting value is 0.5 to 2 times of info STORE-RECALL length i.e. total time SNN takes to
|
||||
# perform a task, for example, equals to episode length.
|
||||
# From "Spike frequency adaptation supports network computations on temporally dispersed
|
||||
# information"
|
||||
|
||||
:synapticConnectionPercent => 50, # % coverage of total neurons in kfn
|
||||
:synapticConnectionPercent => 20, # % coverage of total neurons in kfn
|
||||
:w_rec_generation_pattern => "random",
|
||||
)
|
||||
|
||||
# linear_neuron_params = Dict{Symbol, Any}(
|
||||
# :type => "linearNeuron",
|
||||
# :v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate)
|
||||
# :tau_out => 50.0, # output time constant in millisecond.
|
||||
# :synapticConnectionPercent => 100, # % coverage of total neurons in kfn
|
||||
# # Good starting value is 1/50th of tau_a
|
||||
# # This is problem specific parameter.
|
||||
# # It controls how leaky the neuron is.
|
||||
# # Too high(less leaky) makes learning algo harder to move model into direction that reduce error
|
||||
# # resulting in model's error to explode exponantially. For example, model error from 7 to 2e6
|
||||
# # One can image training output neuron is like Tetris Game.
|
||||
# )
|
||||
linear_neuron_params = Dict{Symbol, Any}(
|
||||
:type => "linearNeuron",
|
||||
:v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate)
|
||||
:tau_out => 50.0, # output time constant in millisecond.
|
||||
:synapticConnectionPercent => 20, # % coverage of total neurons in kfn
|
||||
# Good starting value is 1/50th of tau_a
|
||||
# This is problem specific parameter.
|
||||
# It controls how leaky the neuron is.
|
||||
# Too high(less leaky) makes learning algo harder to move model into direction that reduce error
|
||||
# resulting in model's error to explode exponantially. For example, model error from 7 to 2e6
|
||||
# One can image training output neuron is like Tetris Game.
|
||||
)
|
||||
|
||||
integrate_neuron_params = Dict{Symbol, Any}(
|
||||
:type => "integrateNeuron",
|
||||
:synapticConnectionPercent => 100, # % coverage of total neurons in kfn
|
||||
:eta => 1e-2,
|
||||
:tau_out => 100.0,
|
||||
:eta => 1e-6,
|
||||
:tau_out => 50.0,
|
||||
# Good starting value is 1/50th of tau_a
|
||||
# This is problem specific parameter.
|
||||
# It controls how leaky the neuron is.
|
||||
@@ -146,7 +145,7 @@ function generate_snn(filename::String, location::String)
|
||||
I_kfnparams = Dict{Symbol, Any}(
|
||||
:knowledgeFnName=> "I",
|
||||
:computeNeuronNumber=> computeNeuronNumber,
|
||||
:neuronFiringRateTarget=> 10.0, # Hz
|
||||
:neuronFiringRateTarget=> 20.0, # Hz
|
||||
:Bn=> "random", # error projection coefficient for EACH neuron
|
||||
:totalNeurons=> totalNeurons,
|
||||
:totalInputPort=> totalInputPort,
|
||||
@@ -315,7 +314,7 @@ function train_snn(model_name::String, filename::String, location::String,
|
||||
axislegend(subfig1, position = :lb)
|
||||
|
||||
subfig2 = GLMakie.Axis(fig1[2, 1], # define position of this subfigure inside a figure
|
||||
title = "output neurons activation",
|
||||
title = "output neurons logit",
|
||||
xlabel = "time",
|
||||
ylabel = "data"
|
||||
)
|
||||
@@ -334,7 +333,7 @@ function train_snn(model_name::String, filename::String, location::String,
|
||||
|
||||
|
||||
subfig3 = GLMakie.Axis(fig1[3, 1], # define position of this subfigure inside a figure
|
||||
title = "output neurons membrane potential v_t1",
|
||||
title = "last RSNN wRec",
|
||||
xlabel = "time",
|
||||
ylabel = "data"
|
||||
)
|
||||
@@ -351,7 +350,7 @@ function train_snn(model_name::String, filename::String, location::String,
|
||||
axislegend(subfig3, position = :lb)
|
||||
|
||||
subfig4 = GLMakie.Axis(fig1[4, 1], # define position of this subfigure inside a figure
|
||||
title = "output neuron wRec",
|
||||
title = "RSNN v_t1",
|
||||
xlabel = "time",
|
||||
ylabel = "data"
|
||||
)
|
||||
@@ -407,16 +406,18 @@ function train_snn(model_name::String, filename::String, location::String,
|
||||
|
||||
# model learning
|
||||
maxRepeatRound = 1 # repeat each image
|
||||
thinkingPeriod = 16 # 1000-784 = 216
|
||||
thinkingPeriod = 16 # 1000-784 = 216
|
||||
bestAccuracy = 0.0
|
||||
for epoch = 1:1000
|
||||
println("epoch $epoch")
|
||||
batchCounter = 0
|
||||
for (imgBatch, labelBatch) in trainData
|
||||
batchCounter += 1
|
||||
println("epoch $epoch batch $batchCounter")
|
||||
@showprogress for i in eachindex(labelBatch)
|
||||
_img = (imgBatch[:, :, i])
|
||||
img = reshape(_img, (:, 1))
|
||||
row, col = size(img)
|
||||
label = labelBatch[i]
|
||||
println("epoch $epoch training label $label")
|
||||
|
||||
img_tensor = torch.from_numpy( np.asarray(img) )
|
||||
|
||||
@@ -518,43 +519,23 @@ function train_snn(model_name::String, filename::String, location::String,
|
||||
var3 = [var3;; _var3]
|
||||
var4 = [var4;; _var4]
|
||||
|
||||
# if tick <= row # online learning, 1-by-1 timestep
|
||||
# correctAnswer = zeros(length(logit))
|
||||
# modelError = (logit - correctAnswer) * 1.0
|
||||
# Ironpen.compute_wRecChange!(model, modelError, correctAnswer)
|
||||
# elseif tick == row+1
|
||||
# correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
# modelError = (logit - correctAnswer) * 1.0
|
||||
# Ironpen.compute_wRecChange!(model, modelError, correctAnswer)
|
||||
# elseif tick > row+1 && tick < row+thinkingPeriod
|
||||
# correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
# modelError = (logit - correctAnswer) * 1.0
|
||||
# Ironpen.compute_wRecChange!(model, modelError, correctAnswer)
|
||||
# elseif tick == row+thinkingPeriod
|
||||
# _predict = logitLog[:, end-thinkingPeriod+1:end] # answer count during thinking period
|
||||
# _predict = Int.([sum(row) for row in eachrow(_predict)])
|
||||
# # predict = [x > 0 for x in _predict]
|
||||
# correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
# modelError = (logit - correctAnswer) * 1.0
|
||||
# Ironpen.compute_wRecChange!(model, modelError, correctAnswer)
|
||||
# Ironpen.learn!(model)
|
||||
# println("label $label predict $(_predict) model error $(Int.(modelError))")
|
||||
# else
|
||||
# error("undefined condition line $(@__LINE__)")
|
||||
# end
|
||||
|
||||
if tick <= row # online learning, 1-by-1 timestep
|
||||
if tick < row # online learning, 1-by-1 timestep
|
||||
# no error calculation
|
||||
elseif tick == row # online learning, 1-by-1 timestep
|
||||
correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
modelError = Flux.logitcrossentropy(logit, correctAnswer) * 1.0
|
||||
outputError = (logit - correctAnswer) * 1.0
|
||||
Ironpen.compute_paramsChange!(model, modelError, outputError)
|
||||
elseif tick > row && tick < row+thinkingPeriod
|
||||
# correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
# modelError = (logit - correctAnswer) * 1.0
|
||||
# Ironpen.compute_wRecChange!(model, modelError, correctAnswer)
|
||||
|
||||
correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
modelError = Flux.logitcrossentropy(logit, correctAnswer) * 1.0
|
||||
outputError = (logit - correctAnswer) * 1.0
|
||||
Ironpen.compute_paramsChange!(model, modelError, outputError)
|
||||
elseif tick == row+thinkingPeriod
|
||||
correctAnswer = OneHotArrays.onehot(label, labelDict)
|
||||
modelError = Flux.logitcrossentropy(logit, correctAnswer) * 1.0
|
||||
outputError = (logit - correctAnswer) * 1.0
|
||||
Ironpen.compute_wRecChange!(model, modelError, outputError)
|
||||
Ironpen.compute_paramsChange!(model, modelError, outputError)
|
||||
Ironpen.learn!(model)
|
||||
_logit = round.(logit; digits=2)
|
||||
predict = findall(isequal.(logit, maximum(logit)))[1] - 1
|
||||
@@ -670,7 +651,8 @@ function train_snn(model_name::String, filename::String, location::String,
|
||||
# check accuracy
|
||||
println("validating model")
|
||||
answerCorrectly = validate(model, validateData, labelDict)
|
||||
println("model accuracy is $answerCorrectly %")
|
||||
bestAccuracy = answerCorrectly > bestAccuracy ? answerCorrectly : bestAccuracy
|
||||
println("model accuracy is $answerCorrectly %, best accuracy is $bestAccuracy")
|
||||
end
|
||||
|
||||
# # check mean error and accuracy
|
||||
@@ -819,8 +801,3 @@ end
|
||||
!isinteractive() && main()
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user