# example_main.jl — working milestone with 25% accuracy (new file, 826 lines in originating commit)
|
|||||||
|
# ---------------------------------------------------------------- environment setup
# Activate the local project and make sure the manifest is satisfied.
# (The original chained these with a comma, which built and discarded a tuple;
# semicolons express the intended sequence of statements.)
using Pkg; Pkg.activate("."); Pkg.resolve(); Pkg.instantiate()

using Revise
using Flux #, CUDA
using BSON, JSON3
using MLDatasets: MNIST
using MLUtils, Images, ProgressMeter, Dates, DataFrames, Random, Statistics, LinearAlgebra,
    BenchmarkTools, Serialization, OneHotArrays, GLMakie # ClickHouse

# if one needs to reinstall all python packages:
# try Pkg.rm("PythonCall") catch end # should be removed before using CondaPkg to install packages
# condapackage = ["numpy", "pytorch", "snntorch"]
# using CondaPkg # in CondaPkg.toml file, channels = ["anaconda", "conda-forge", "pytorch"]
# for i in condapackage
#     try CondaPkg.rm(i) catch end
# end
# for i in condapackage
#     CondaPkg.add(i)
# end
# Pkg.add("PythonCall");

using PythonCall
np = pyimport("numpy")
torch = pyimport("torch")
spikegen = pyimport("snntorch.spikegen") # https://github.com/jeshraghian/snntorch

using Ironpen
using GeneralUtils

# platform-specific path separator used when joining file paths below
sep = Sys.iswindows() ? "\\" : "/"
rootDir = pwd()

# select compute device
# device = Flux.CUDA.functional() ? gpu : cpu
# if device == gpu
#     CUDA.device!(3)
# end
#------------------------------------------------------------------------------------------------100

"""
Todo:
- []

Change from version:
-

All features
-
"""

# communication config --------------------------------------------------------------------------100
database_ip = "localhost"
# database_ip = "192.168.0.8"
#------------------------------------------------------------------------------------------------100
"""
    generate_snn(filename::String, location::String;
                 expect_compute_neuron_numbers::Int = 1024,
                 signalInput_portnumbers::Int = 50,
                 output_portnumbers::Int = 10)

Build a recurrent spiking neural network out of pass-through input ports, LIF/ALIF
compute neurons and integrate output neurons, then serialize the resulting `Ironpen`
model to `location * sep * filename` (uses the global `sep`).

The keyword arguments generalize the previously hard-coded sizes; the defaults
reproduce the original network exactly (1024 compute neurons, 50 signal input ports
plus an equal number of noise ports, 10 output neurons — one per MNIST digit).
"""
function generate_snn(filename::String, location::String;
        expect_compute_neuron_numbers::Int = 1024, #FIXME change to 512
        signalInput_portnumbers::Int = 50,
        output_portnumbers::Int = 10)

    noise_portnumbers = signalInput_portnumbers

    # 40% LIF; the remainder ALIF
    lif_neuron_number = Int(floor(expect_compute_neuron_numbers * 0.4))
    alif_neuron_number = expect_compute_neuron_numbers - lif_neuron_number # from Allen Institute, ALIF is 20-40% of LIF
    computeNeuronNumber = lif_neuron_number + alif_neuron_number

    totalNeurons = computeNeuronNumber + noise_portnumbers + signalInput_portnumbers
    totalInputPort = noise_portnumbers + signalInput_portnumbers

    # kfn and neuron config
    passthrough_neuron_params = Dict(
        :type => "passthroughNeuron"
    )

    # Parameters shared by the LIF and ALIF compute neurons. These two dicts were
    # previously duplicated by hand; `merge` below keeps them in sync.
    shared_compute_neuron_params = Dict{Symbol, Any}(
        :v_t_default => 0.0,
        :v_th => 1.0, # neuron firing threshold (treated as a maximum bound if auto generation is used)
        :tau_m => 200.0, # membrane time constant in millisecond.
        :eta => 1e-2,
        # Good starting value is 1/10th of tau_a.
        # This is a problem-specific parameter. It controls how leaky the neuron is.
        # Too high (less leaky) makes the learning algo harder to move the model in a
        # direction that reduces error, so the model's error can explode exponentially
        # (e.g. from 7 to 2e6), likely because the learning algo tries to exert more
        # force (larger w_out_change).

        :synapticConnectionPercent => 50, # % coverage of total neurons in kfn
        :w_rec_generation_pattern => "random",
    )

    lif_neuron_params = merge(shared_compute_neuron_params,
        Dict{Symbol, Any}(:type => "lifNeuron"))

    alif_neuron_params = merge(shared_compute_neuron_params, Dict{Symbol, Any}(
        :type => "alifNeuron",
        :tau_a => 500.0, # adaptation time constant in millisecond; defines neuron memory length.
        # Problem-specific parameter. A good starting value is 0.5 to 2 times of the
        # info STORE-RECALL length, i.e. total time the SNN takes to perform a task
        # (for example, equal to the episode length).
        # From "Spike frequency adaptation supports network computations on
        # temporally dispersed information".
    ))

    integrate_neuron_params = Dict{Symbol, Any}(
        :type => "integrateNeuron",
        :synapticConnectionPercent => 100, # % coverage of total neurons in kfn
        :eta => 1e-2,
        :tau_out => 100.0,
        # Good starting value is 1/50th of tau_a.
        # Problem-specific parameter; controls how leaky the neuron is.
        # Too high (less leaky) makes the learning algo harder to move the model in a
        # direction that reduces error, so the error can explode exponentially
        # (e.g. from 7 to 2e6). One can imagine training an output neuron is like a
        # Tetris game.
    )

    I_kfnparams = Dict{Symbol, Any}(
        :knowledgeFnName => "I",
        :computeNeuronNumber => computeNeuronNumber,
        :neuronFiringRateTarget => 10.0, # Hz
        :Bn => "random", # error projection coefficient for EACH neuron
        :totalNeurons => totalNeurons,
        :totalInputPort => totalInputPort,
        :totalComputeNeuron => computeNeuronNumber,

        # group-relevant info
        :inputPort => Dict(
            :noise => Dict(
                :numbers => noise_portnumbers,
                :params => passthrough_neuron_params,
            ),
            :signal => Dict(
                :numbers => signalInput_portnumbers, # in case of GloVe word encoding, it is 300
                :params => passthrough_neuron_params,
            ),
        ),
        :outputPort => Dict(
            :numbers => output_portnumbers, # output neurons; this is also the output length
            :params => integrate_neuron_params,
        ),
        # NOTE: the original wrote `:1 =>` / `:2 =>`, which Julia parses as the
        # integer literals 1 and 2 — the keys below are identical.
        :computeNeuron => Dict(
            1 => Dict(
                :numbers => lif_neuron_number,
                :params => lif_neuron_params,
            ),
            2 => Dict(
                :numbers => alif_neuron_number,
                :params => alif_neuron_params,
            ),
        ),
    )

    #------------------------------------------------------------------------------------------------100

    I_kfn = Ironpen.kfn_1(I_kfnparams)

    model_params_1 = Dict(:knowledgeFn => Dict(:I => I_kfn))

    model = Ironpen.model(model_params_1)

    serialize(location * sep * filename, model)
    println("SNN generated")
end
"""
    data_loader() -> (trainData, validateData, labelDict)

Create `MLUtils.DataLoader`s over the MNIST training split.

Returns the training loader (shuffled batches of 100, collated, no partial
batches), a validation loader (batch size 1) over the first 100 training samples,
and the label dictionary `[0, 1, ..., 9]`.
"""
function data_loader()
    # test problem
    fullTrainDataset = MNIST(:train) # total 60000
    trainDataset = fullTrainDataset  # fullTrainDataset or a subset

    # NOTE(review): the validation samples are drawn from the training split, so the
    # reported accuracy is optimistic — consider MNIST(:test) for a held-out set.
    validateDataset = fullTrainDataset[1:100]
    labelDict = [0:9...]

    trainData = MLUtils.DataLoader(
        trainDataset;
        batchsize=100,
        collate=true,
        shuffle=true,
        buffer=true,
        partial=false, # better for gpu memory if batchsize is fixed
        # parallel=true, #BUG ?? causing dataloader into forever loop
    )

    validateData = MLUtils.DataLoader(
        validateDataset;
        batchsize=1,
        collate=true,
        shuffle=true,
        buffer=true,
        partial=false, # better for gpu memory if batchsize is fixed
        # parallel=true, #BUG ?? causing dataloader into forever loop
    )

    #CHANGE dummy data used to debug
    # trainData = [(rand(10, 10), [5]), (rand(10, 10), [2])]
    # trainData = [(rand(10, 10), [5]),]

    return trainData, validateData, labelDict
end
# Encode a flattened image (Nx1 matrix, values presumably in [0,1] — MNIST pixels)
# as 50 binary spike trains via snntorch delta modulation: for each threshold in
# 0.1..0.5, five copies of the on-spike train (+1 events) followed by five copies of
# the off-spike train (-1 events). The original code recomputed the identical
# `pyconvert(Array, spike.data.numpy())` for every one of the 50 copies; here each
# train is converted once and reused.
function _spike_encode(img)
    img_tensor = torch.from_numpy( np.asarray(img) )
    columns = []
    for threshold in (0.1, 0.2, 0.3, 0.4, 0.5)
        spike = spikegen.delta(img_tensor, threshold=threshold, off_spike=true)
        events = pyconvert(Array, spike.data.numpy())
        on  = isequal.(events, 1)   # on-spikes
        off = isequal.(events, -1)  # off-spikes
        append!(columns, fill(on, 5))
        append!(columns, fill(off, 5))
    end
    return reduce(hcat, columns)' # ' flips 784x50 to 50x784 (one row per spike train)
end

# Add one axis at fig[rowindex, 1] with ten labelled traces (digits 0-9) drawn from
# `observables` (a vector of 10 Observables), using the :tab10 colormap.
function _digit_axis!(fig, rowindex::Int, title::String, observables)
    ax = GLMakie.Axis(fig[rowindex, 1], # define position of this subfigure inside a figure
        title = title,
        xlabel = "time",
        ylabel = "data"
    )
    for (d, obs) in enumerate(observables)
        lines!(ax, obs, label = string(d - 1), color = d, colormap = :tab10, colorrange = (1, 10))
    end
    axislegend(ax, position = :lb)
    return ax
end

"""
    train_snn(model_name::String, filename::String, location::String,
              trainData, validateData, labelDict::Vector)

Deserialize the SNN model from `location * sep * filename`, train it on `trainData`
(sequential MNIST: each image is streamed pixel-by-pixel, followed by a "thinking"
period, with one learning step at the final timestep), live-plot internal state
with GLMakie, and report validation accuracy after every batch.

`model_name` is currently unused but kept for interface compatibility.
"""
function train_snn(model_name::String, filename::String, location::String,
        trainData, validateData, labelDict::Vector)
    println("loading SNN model")
    model = deserialize(location * sep * filename)
    println("model loading completed")

    # random seed
    # rng = MersenneTwister(1234)

    # Rolling logs of per-timestep model outputs; one column is appended per tick.
    logitLog = zeros(10, 2)
    firedNeurons_t1 = zeros(1)
    var1 = zeros(10, 2)
    var2 = zeros(10, 2)
    var3 = zeros(10, 2)
    var4 = zeros(10, 2)

    # ----------------------------------- plot ----------------------------------- #
    # One Observable per plotted trace. The original declared all 61 by hand
    # (plot10, plot20..plot69); vectors + loops are behaviorally identical.
    plot10 = Observable(firedNeurons_t1)
    logitObs = [Observable(logitLog[r, :]) for r in 1:10] # output neuron activations
    var1Obs  = [Observable(var1[r, :]) for r in 1:10]     # membrane potential v_t1
    var2Obs  = [Observable(var2[r, :]) for r in 1:10]     # wRec
    var3Obs  = [Observable(var3[r, :]) for r in 1:10]     # epsilonRec
    var4Obs  = [Observable(var4[r, :]) for r in 1:10]     # wRecChange

    # main figure
    fig1 = Figure()

    subfig1 = GLMakie.Axis(fig1[1, 1], # define position of this subfigure inside a figure
        title = "RSNN firedNeurons_t1",
        xlabel = "time",
        ylabel = "data"
    )
    lines!(subfig1, plot10, label = "firedNeurons_t1")
    axislegend(subfig1, position = :lb)

    _digit_axis!(fig1, 2, "output neurons activation", logitObs)
    _digit_axis!(fig1, 3, "output neurons membrane potential v_t1", var1Obs)
    _digit_axis!(fig1, 4, "output neuron wRec", var2Obs)
    _digit_axis!(fig1, 5, "output neuron epsilonRec", var3Obs)
    _digit_axis!(fig1, 6, "output neuron wRecChange", var4Obs)

    # wait(display(fig1))
    # display(fig1)
    # --------------------------------- end plot --------------------------------- #

    # model learning
    maxRepeatRound = 1  # repeat each image at most this many times
    thinkingPeriod = 16 # extra "thinking" timesteps after the 784 pixels
                        # (original comment claimed "1000-784 = 216", which
                        # disagrees with the value 16 — kept the value)

    for epoch = 1:1000
        println("epoch $epoch")
        for (imgBatch, labelBatch) in trainData
            @showprogress for n in eachindex(labelBatch)
                img = reshape(imgBatch[:, :, n], (:, 1)) # flatten 28x28 -> 784x1
                row = size(img, 1)
                label = labelBatch[n]
                println("epoch $epoch training label $label")

                # create more data for RSNN: 50 x 784 Bool spike trains
                input = _spike_encode(img)

                predict = 0

                for k in 1:maxRepeatRound

                    # insert data into model sequentially, one timestep at a time
                    for tick in 1:(row + thinkingPeriod)
                        if tick <= row
                            current_pixel = input[:, tick]
                        else
                            current_pixel = zeros(size(input)[1]) # dummy input in "thinking" period
                        end

                        if tick == 1 # tell the model to start learning. 1-time only
                            model.learningStage = "start_learning"
                        elseif tick == (row + thinkingPeriod)
                            model.learningStage = "end_learning"
                        end

                        _firedNeurons_t1, logit, _var1, _var2, _var3, _var4 = model(current_pixel)

                        # log answer of all timesteps
                        logitLog = [logitLog;; logit]
                        push!(firedNeurons_t1, _firedNeurons_t1)
                        var1 = [var1;; _var1]
                        var2 = [var2;; _var2]
                        var3 = [var3;; _var3]
                        var4 = [var4;; _var4]

                        if tick <= row # online phase: no error calculation while pixels stream in
                            # no error calculation
                        elseif tick < row + thinkingPeriod
                            # thinking phase: no learning signal (per-timestep error
                            # updates were previously experimented with here)
                        elseif tick == row + thinkingPeriod
                            # learn exactly once, at the final timestep
                            correctAnswer = OneHotArrays.onehot(label, labelDict)
                            modelError = Flux.logitcrossentropy(logit, correctAnswer) * 1.0
                            outputError = (logit - correctAnswer) * 1.0
                            Ironpen.compute_wRecChange!(model, modelError, outputError)
                            Ironpen.learn!(model)
                            _logit = round.(logit; digits=2)
                            predict = findall(isequal.(logit, maximum(logit)))[1] - 1
                            y = round.(modelError; digits=2)
                            println("")
                            println("label $label predict $predict logit $_logit model error $y")
                        else
                            error("undefined condition line $(@__LINE__)")
                        end

                        # update plot
                        plot10[] = firedNeurons_t1
                        for (obs, mat) in zip((logitObs, var1Obs, var2Obs, var3Obs, var4Obs),
                                              (logitLog, var1, var2, var3, var4))
                            for r in 1:10
                                obs[r][] = view(mat, r, :)
                            end
                        end
                    end

                    # end-thinkingPeriod+2; +2 because initialize logitLog = zeros(10, 2)
                    # _modelRespond = logitLog[:, end-thinkingPeriod+2:end] # answer count during thinking period
                    # _modelRespond = [sum(i) for i in eachrow(_modelRespond)]
                    # modelRespond = isequal.(isequal.(_modelRespond, 0), 0)

                    display(fig1)
                    # sleep(1)
                    if k % 3 == 0 # NOTE(review): never fires while maxRepeatRound == 1
                        firedNeurons_t1 = zeros(1)
                        logitLog = zeros(10, 2)
                        var1 = zeros(10, 2)
                        var2 = zeros(10, 2)
                        var3 = zeros(10, 2)
                        var4 = zeros(10, 2)
                    end

                    if k == maxRepeatRound
                        # println("model train $label unsuccessfully, $maxRepeatRound tries, skip training")
                        # reset the rolling logs between images
                        firedNeurons_t1 = zeros(1)
                        logitLog = zeros(10, 2)
                        var1 = zeros(10, 2)
                        var2 = zeros(10, 2)
                        var3 = zeros(10, 2)
                        var4 = zeros(10, 2)
                        break
                    end

                    GC.gc()
                end
            end
            # check accuracy after every batch
            println("validating model")
            answerCorrectly = validate(model, validateData, labelDict)
            println("model accuracy is $answerCorrectly %")
        end
    end
end
"""
    validate(model, dataset, labelDict)

Run `model` over every `(image, label)` pair in `dataset` and return the
classification accuracy in percent (`Float64`).

Each image is delta-modulated into spike trains at five thresholds
(0.1 … 0.5) via `snntorch.spikegen.delta`; for every threshold the on-spike
(`+1`) channel and the off-spike (`-1`) channel are each replicated five
times (population encoding), giving a 50-channel input. Pixels are fed to
the model one timestep at a time, followed by `thinkingPeriod` timesteps of
zero input; the prediction is the output neuron with the highest final logit.

NOTE(review): `labelDict` is accepted for interface compatibility but is not
used in this body — confirm whether callers still need it.
"""
function validate(model, dataset, labelDict)
    answerCorrectly = 0.0            # number of correct predictions
    thinkingPeriod = 16              # extra zero-input timesteps after the image
    copiesPerPolarity = 5            # population encoding: identical channels per spike polarity

    @showprogress for (image, label) in dataset
        img = reshape(image, (:, 1))     # flatten to a 784x1 column
        row, _ = size(img)
        label = label[1]

        img_tensor = torch.from_numpy(np.asarray(img))

        # Build the 50 population-encoded channels. The Python round-trip is
        # hoisted: the original recomputed the identical converted array ten
        # times per threshold.
        channels = BitMatrix[]
        for threshold in (0.1, 0.2, 0.3, 0.4, 0.5)
            spike = spikegen.delta(img_tensor, threshold=threshold, off_spike=true)
            spikeArr = pyconvert(Array, spike.data.numpy())
            onSpikes = isequal.(spikeArr, 1)
            offSpikes = isequal.(spikeArr, -1)
            for _ in 1:copiesPerPolarity
                push!(channels, onSpikes)
            end
            for _ in 1:copiesPerPolarity
                push!(channels, offSpikes)
            end
        end
        input = permutedims(reduce(hcat, channels))  # 784x50 -> 50x784

        # Feed the spike channels into the model one timestep (pixel) at a
        # time, then keep stepping with dummy zero input during the
        # "thinking" period; only the final logit vector matters.
        logit = Float64[]
        for t in 1:(row + thinkingPeriod)
            current_pixel = t <= row ? input[:, t] : zeros(size(input, 1))
            _, logit, _, _, _, _ = model(current_pixel)
        end

        # First output neuron corresponds to digit 0, hence the -1 offset.
        predict = argmax(logit) - 1
        if predict == label
            answerCorrectly += 1
            # println("model answer $label correctly")
        else
            # println("img $label, model answer $predict")
        end

        GC.gc()  # presumably relieves PythonCall/torch memory pressure — TODO confirm it is still needed
    end

    correctPercent = answerCorrectly * 100.0 / length(dataset)

    return correctPercent::Float64
end
|
||||||
|
|
||||||
|
|
||||||
|
"""
    main()

Top-level training driver: generate the SNN from the model definition file,
load the MNIST-derived train/validate datasets, run training, and report
wall-clock start/finish times.
"""
function main()
    training_start_time = Dates.now()
    println("program started ", training_start_time)

    filelocation = string(@__DIR__)

    # Model identifier; ".jl163" is the project's model-definition extension.
    modelname = "v06_36"
    filename = "$modelname.jl163"
    # filename = "v06_31c.jl163"

    # generate SNN (the original wrapped this single call in a `for i = 1:1`
    # loop and then recomputed the same modelname/filename afterwards)
    generate_snn(filename, filelocation)

    trainDataset, validateDataset, labelDict = data_loader()

    train_snn(modelname, filename, filelocation, trainDataset, validateDataset, labelDict)

    finish_training_time = Dates.now()
    println("training done, $training_start_time ==> $finish_training_time ")
    println(" ///////////////////////////////////////////////////////////////////////")
end
|
||||||
|
|
||||||
|
# Run main() only when Julia was started non-interactively (i.e. as a script).
# https://discourse.julialang.org/t/scripting-like-a-julian/50707
if !isinteractive()
    main()
end
|
||||||
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -34,7 +34,6 @@ using .learn
|
|||||||
|
|
||||||
""" version 0.0.5
|
""" version 0.0.5
|
||||||
Todo:
|
Todo:
|
||||||
-
|
|
||||||
[4] implement dormant connection
|
[4] implement dormant connection
|
||||||
[] using RL to control learning signal
|
[] using RL to control learning signal
|
||||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||||
@@ -43,34 +42,60 @@ using .learn
|
|||||||
|
|
||||||
|
|
||||||
Change from version: 0.0.4
|
Change from version: 0.0.4
|
||||||
- compute error in main loop so one could decide how to calculate error
|
- compute model error in main loop so one could decide when to calculate error in
|
||||||
- compute model error in main loop so one could decide when to calculate error
|
training sequence and how to calculate
|
||||||
|
- fix ALIF adaptation formula, now n.a compute avery time step
|
||||||
|
- add higher input signal to noise ratio
|
||||||
|
- no noise generate
|
||||||
|
- increase input signal by adding more input neuron population
|
||||||
|
- add /100 every wRec and b
|
||||||
|
- add integrateNeuron
|
||||||
|
- ΔwRecChange during input signal ingestion will be merged at the end of learning
|
||||||
|
- use Flux.logitcrossentropy for overall error
|
||||||
|
- move timestep_forward!() in kfn's forward to the beginning so that v_t and z_t is reset
|
||||||
|
- fix n.a formula in forward() and calculate both non-firing and firing state
|
||||||
|
- RSNN use overall modelError to update while integrate neuron use error with respect to
|
||||||
|
itself (yk - yk*) to update
|
||||||
|
- all RSNN neuron connect to integrateNeuron
|
||||||
|
- integrateNeuron does NOT repect RSNN excitatory and inhabitory sign
|
||||||
|
- weaker connection should be harder to increase strength. It requires a lot of
|
||||||
|
repeat activation to get it stronger. While strong connction requires a lot of
|
||||||
|
inactivation to get it weaker. The concept is strong connection will lock
|
||||||
|
correct neural pathway through repeated use of the right connection i.e. keep training
|
||||||
|
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||||
|
this correct neural pathway resist to change.
|
||||||
|
Not used connection should dissapear (forgetting).
|
||||||
|
|
||||||
|
|
||||||
All features
|
All features
|
||||||
- ΔwRecChange during input signal ingestion will be merged at the end of learning
|
|
||||||
- all RSNN and output neuron learning associate.
|
|
||||||
- synapticStrength apply at the end of learning
|
- synapticStrength apply at the end of learning
|
||||||
- collect ΔwRecChange during online learning (0-784th) and merge with wRec at
|
- collect ΔwRecChange during online learning (0-784th) and merge with wRec at
|
||||||
the end learning (1000th).
|
the end learning (800th).
|
||||||
- compute model error at the end learning. Model error times with 5 constant for
|
|
||||||
higher learning impact than the error during online
|
|
||||||
- multidispatch + for loop as main compute method
|
- multidispatch + for loop as main compute method
|
||||||
- hard connection constrain yes
|
|
||||||
- normalize output yes
|
|
||||||
- allow -w_rec yes
|
- allow -w_rec yes
|
||||||
- voltage drop when neuron fires voltage drop equals to vth
|
- voltage drop when neuron fires voltage drop equals to vRest
|
||||||
- v_t decay during refractory
|
- v_t decay during refractory
|
||||||
duration exponantial decay
|
|
||||||
- input data population encoding, each pixel data =>
|
- input data population encoding, each pixel data =>
|
||||||
population encoding, ralative between pixel data
|
population encoding, ralative between pixel data
|
||||||
- compute neuron weight init rand()
|
- compute neuron weight init rand()
|
||||||
- output neuron weight init randn()
|
- output neuron weight init randn()
|
||||||
- each knowledgeFn should have its own noise generater
|
|
||||||
- where to put pseudo derivative (n.phi)
|
- compute pseudo derivative (n.phi) every time step
|
||||||
- add excitatory, inhabitory to neuron
|
- add excitatory, inhabitory to neuron
|
||||||
- implement "start learning", reset learning and "learning", "end_learning and
|
- implement "start learning", reset learning and "learning", "end_learning and
|
||||||
"inference"
|
"inference"
|
||||||
|
- synaptic connection strength concept. use sigmoid, turn connection offline
|
||||||
|
- neuroplasticity() i.e. change connection
|
||||||
|
- add multi threads
|
||||||
|
|
||||||
|
|
||||||
|
Removed features
|
||||||
|
- normalize output yes
|
||||||
|
<logitcrossentropy does not need normalization>
|
||||||
|
- compute model error at the end learning. Model error times with 5 constant for
|
||||||
|
higher learning impact than the error during online
|
||||||
|
<there should be no difference between error in each timestep because error
|
||||||
|
has equal importance>
|
||||||
- output neuron connect to random multiple compute neurons and overall have
|
- output neuron connect to random multiple compute neurons and overall have
|
||||||
the same structure as lif
|
the same structure as lif
|
||||||
- time-based learning method based on new error formula
|
- time-based learning method based on new error formula
|
||||||
@@ -79,28 +104,23 @@ using .learn
|
|||||||
(vth - vt)*100/vth as error
|
(vth - vt)*100/vth as error
|
||||||
if output neuron activates when it should NOT, use output neuron's
|
if output neuron activates when it should NOT, use output neuron's
|
||||||
(vt*100)/vth as error
|
(vt*100)/vth as error
|
||||||
|
<use logitcrossentropy>
|
||||||
- use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge
|
- use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge
|
||||||
|
<it reduce instant respond of neuron. Sometime postsynaptic neuron need to
|
||||||
|
respond quickly at differnt neural pathway. If wRec is normalized, weights that
|
||||||
|
needs to be high to allow neuron instant respond get reduced.>
|
||||||
- reset_epsilonRec after ΔwRecChange is calculated
|
- reset_epsilonRec after ΔwRecChange is calculated
|
||||||
- synaptic connection strength concept. use sigmoid, turn connection offline
|
<training example does not require intermediate respond from RSNN>
|
||||||
- wRec should not normalized whole. it should be local 5 conn normalized.
|
- add maximum weight cap of each connection
|
||||||
- neuroplasticity() i.e. change connection
|
<capping weight limit neuron ability to adjust its respond>
|
||||||
- add multi threads
|
- wRec should not normalized whole. it should be local 5 conn normalized.
|
||||||
- add maximum weight cap of each connection
|
<it makes small weight bigger>
|
||||||
|
|
||||||
|
|
||||||
Removed features
|
|
||||||
|
Ideas to try
|
||||||
|
- Δweight * connection strength
|
||||||
|
- reset_epsilonRec after ΔwRecChange is calculated
|
||||||
- ΔwRecChange that apply immediately during online learning
|
- ΔwRecChange that apply immediately during online learning
|
||||||
- error by percent of vth-v_t1
|
|
||||||
- Δweight * connection strength
|
|
||||||
- weaker connection should be harder to increase strength. It requires a lot of
|
|
||||||
repeat activation to get it stronger. While strong connction requires a lot of
|
|
||||||
inactivation to get it weaker. The concept is strong connection will lock
|
|
||||||
correct neural pathway through repeated use of the right connection i.e. keep training
|
|
||||||
on the correct answer -> strengthen the right neural pathway (connections) ->
|
|
||||||
this correct neural pathway resist to change.
|
|
||||||
Not used connection should dissapear (forgetting).
|
|
||||||
- during 0 training if 1-9 output neuron fires, adjust weight only those neurons
|
|
||||||
- use Flux.logitcrossentropy for overall error
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
module forward
|
module forward
|
||||||
|
|
||||||
using Statistics, Random, LinearAlgebra, JSON3
|
using Statistics, Random, LinearAlgebra, JSON3, Flux
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types, ..snn_utils
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
@@ -26,7 +26,13 @@ end
|
|||||||
|
|
||||||
function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||||
kfn.timeStep = m.timeStep
|
kfn.timeStep = m.timeStep
|
||||||
|
for n in kfn.neuronsArray
|
||||||
|
timestep_forward!(n)
|
||||||
|
end
|
||||||
|
for n in kfn.outputNeuronsArray
|
||||||
|
timestep_forward!(n)
|
||||||
|
end
|
||||||
|
|
||||||
kfn.learningStage = m.learningStage
|
kfn.learningStage = m.learningStage
|
||||||
|
|
||||||
if kfn.learningStage == "start_learning"
|
if kfn.learningStage == "start_learning"
|
||||||
@@ -54,19 +60,12 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
end
|
end
|
||||||
|
|
||||||
# generate noise
|
# generate noise
|
||||||
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
|
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.0, 1.0])
|
||||||
for i in 1:length(input_data)]
|
for i in 1:length(input_data)]
|
||||||
# noise = [rand(rng, Distributions.Binomial(1, 0.5)) for i in 1:10] # another option
|
# noise = [rand(rng, Distributions.Binomial(1, 0.5)) for i in 1:10] # another option
|
||||||
|
|
||||||
input_data = [noise; input_data] # noise must start from neuron id 1
|
input_data = [noise; input_data] # noise must start from neuron id 1
|
||||||
|
|
||||||
for n in kfn.neuronsArray
|
|
||||||
timestep_forward!(n)
|
|
||||||
end
|
|
||||||
for n in kfn.outputNeuronsArray
|
|
||||||
timestep_forward!(n)
|
|
||||||
end
|
|
||||||
|
|
||||||
# pass input_data into input neuron.
|
# pass input_data into input neuron.
|
||||||
# number of data point equals to number of input neuron starting from id 1
|
# number of data point equals to number of input neuron starting from id 1
|
||||||
for (i, data) in enumerate(input_data)
|
for (i, data) in enumerate(input_data)
|
||||||
@@ -75,8 +74,8 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
|
|
||||||
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray]
|
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray]
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
Threads.@threads for n in kfn.neuronsArray
|
||||||
for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
n(kfn)
|
n(kfn)
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -84,19 +83,22 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
|
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
|
||||||
kfn.firedNeurons |> unique! # use for random new neuron connection
|
kfn.firedNeurons |> unique! # use for random new neuron connection
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.outputNeuronsArray
|
Threads.@threads for n in kfn.outputNeuronsArray
|
||||||
for n in kfn.outputNeuronsArray
|
# for n in kfn.outputNeuronsArray
|
||||||
n(kfn)
|
n(kfn)
|
||||||
end
|
end
|
||||||
|
|
||||||
out = [n.z_t1 for n in kfn.outputNeuronsArray]
|
logit = [n.v_t1 for n in kfn.outputNeuronsArray]
|
||||||
|
|
||||||
return out::Array{Bool},
|
# _predict = Flux.softmax(logit)
|
||||||
sum(kfn.firedNeurons_t1),
|
# predict = findall(isequal.(_predict, maximum(_predict)))[1]
|
||||||
|
|
||||||
|
return sum(kfn.firedNeurons_t1[kfn.kfnParams[:totalInputPort]+1:end])::Int,
|
||||||
|
logit::Array{Float64},
|
||||||
[n.v_t1 for n in kfn.outputNeuronsArray],
|
[n.v_t1 for n in kfn.outputNeuronsArray],
|
||||||
[sum(i.wRec) for i in kfn.outputNeuronsArray],
|
[sum(i.wRec) for i in kfn.outputNeuronsArray],
|
||||||
[sum(i.epsilonRec) for i in kfn.outputNeuronsArray],
|
[sum(i.epsilonRec) for i in kfn.outputNeuronsArray],
|
||||||
[i.phi for i in kfn.outputNeuronsArray]
|
[sum(i.wRecChange) for i in kfn.outputNeuronsArray]
|
||||||
end
|
end
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
@@ -128,11 +130,15 @@ function (n::lifNeuron)(kfn::knowledgeFn)
|
|||||||
|
|
||||||
# decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
|
|
||||||
|
n.phi = 0.0
|
||||||
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
|
n.epsilonRec = n.decayedEpsilonRec
|
||||||
else
|
else
|
||||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||||
n.v_t1 = no_negative!(n.v_t1)
|
# n.v_t1 = no_negative!(n.v_t1)
|
||||||
|
|
||||||
if n.v_t1 > n.v_th
|
if n.v_t1 > n.v_th
|
||||||
n.z_t1 = true
|
n.z_t1 = true
|
||||||
@@ -147,7 +153,7 @@ function (n::lifNeuron)(kfn::knowledgeFn)
|
|||||||
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
||||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
@@ -165,26 +171,30 @@ function (n::alifNeuron)(kfn::knowledgeFn)
|
|||||||
|
|
||||||
# neuron is in refractory state, skip all calculation
|
# neuron is in refractory state, skip all calculation
|
||||||
n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because neuron spike last only 1 timestep follow by a period of refractory.
|
n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because neuron spike last only 1 timestep follow by a period of refractory.
|
||||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
n.a = (n.rho * n.a)
|
||||||
n.recSignal = n.recSignal * 0.0
|
n.recSignal = n.recSignal * 0.0
|
||||||
|
|
||||||
# decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
n.phi = 0
|
|
||||||
|
n.phi = 0.0
|
||||||
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
|
n.epsilonRec = n.decayedEpsilonRec
|
||||||
else
|
else
|
||||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
|
||||||
n.av_th = n.v_th + (n.beta * n.a)
|
n.av_th = n.v_th + (n.beta * n.a)
|
||||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||||
n.v_t1 = no_negative!(n.v_t1)
|
# n.v_t1 = no_negative!(n.v_t1)
|
||||||
if n.v_t1 > n.av_th
|
if n.v_t1 > n.av_th
|
||||||
n.z_t1 = true
|
n.z_t1 = true
|
||||||
n.refractoryCounter = n.refractoryDuration
|
n.refractoryCounter = n.refractoryDuration
|
||||||
n.firingCounter += 1
|
n.firingCounter += 1
|
||||||
n.v_t1 = n.vRest
|
n.v_t1 = n.vRest
|
||||||
|
n.a = (n.rho * n.a) + 1.0
|
||||||
else
|
else
|
||||||
n.z_t1 = false
|
n.z_t1 = false
|
||||||
|
n.a = (n.rho * n.a)
|
||||||
end
|
end
|
||||||
|
|
||||||
# there is a difference from lif formula
|
# there is a difference from lif formula
|
||||||
@@ -219,12 +229,16 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
|||||||
# decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
||||||
|
|
||||||
|
n.phi = 0.0
|
||||||
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
|
n.epsilonRec = n.decayedEpsilonRec
|
||||||
else
|
else
|
||||||
recSignal = n.wRec .* n.z_i_t
|
recSignal = n.wRec .* n.z_i_t
|
||||||
n.recSignal = sum(recSignal) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(recSignal) # signal from other neuron that this neuron subscribed
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||||
n.v_t1 = no_negative!(n.v_t1)
|
# n.v_t1 = no_negative!(n.v_t1)
|
||||||
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
||||||
if n.v_t1 > n.v_th
|
if n.v_t1 > n.v_th
|
||||||
n.z_t1 = true
|
n.z_t1 = true
|
||||||
@@ -242,6 +256,30 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
|
""" integrateNeuron forward()
|
||||||
|
"""
|
||||||
|
function (n::integrateNeuron)(kfn::knowledgeFn)
|
||||||
|
n.timeStep = kfn.timeStep
|
||||||
|
|
||||||
|
# pulling other neuron's firing status at time t
|
||||||
|
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||||
|
n.z_i_t_commulative += n.z_i_t
|
||||||
|
|
||||||
|
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
|
if n.recSignal <= 0
|
||||||
|
n.v_t1 = n.alpha_v_t
|
||||||
|
else
|
||||||
|
n.v_t1 = n.alpha_v_t + n.recSignal + n.b
|
||||||
|
end
|
||||||
|
|
||||||
|
# there is a difference from alif formula
|
||||||
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
|
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
172
src/learn.jl
172
src/learn.jl
@@ -8,15 +8,9 @@ export learn!, compute_wRecChange!, computeModelError
|
|||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
|
function compute_wRecChange!(m::model, modelError::Float64, outputError::Vector{Float64})
|
||||||
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
|
# normalize!(modelError)
|
||||||
error = ((correctAnswer .- modelRespond) .* magnitude)
|
compute_wRecChange!(m.knowledgeFn[:I], modelError, outputError)
|
||||||
|
|
||||||
return error::Vector{Float64}
|
|
||||||
end
|
|
||||||
|
|
||||||
function compute_wRecChange!(m::model, error::Vector{Float64}, correctAnswer::AbstractVector)
|
|
||||||
compute_wRecChange!(m.knowledgeFn[:I], error, correctAnswer)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
# function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector)
|
# function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector)
|
||||||
@@ -47,54 +41,83 @@ end
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector)
|
function compute_wRecChange!(kfn::kfn_1, modelError::Float64, outputError::Vector{Float64})
|
||||||
for (i, error) in enumerate(errors)
|
Threads.@threads for n in kfn.neuronsArray
|
||||||
if error < 0 # model fires too fast
|
# for n in kfn.neuronsArray
|
||||||
error = error *
|
if typeof(n)<: computeNeuron
|
||||||
abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError)
|
# wIndex = findall(isequal.(oN.subscriptionList, n.id))
|
||||||
elseif error == 0 # model answer correctly. maintain membrain potential ≈ 0.5
|
wOut = abs.([oN.wRec[findall(isequal.(oN.subscriptionList, n.id))[1]]
|
||||||
error = error *
|
for oN in kfn.outputNeuronsArray])
|
||||||
abs(kfn.outputNeuronsArray[i].v_th/2 - kfn.outputNeuronsArray[i].vError)
|
compute_wRecChange!(n, wOut, modelError)
|
||||||
else # model fires too slow
|
|
||||||
error = error *
|
|
||||||
abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError)
|
|
||||||
end
|
end
|
||||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
end
|
||||||
# for n in kfn.neuronsArray
|
|
||||||
compute_wRecChange!(n, error)
|
for oN in kfn.outputNeuronsArray
|
||||||
end
|
compute_wRecChange!(oN, outputError[oN.id])
|
||||||
compute_wRecChange!(kfn.outputNeuronsArray[i], error)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
|
function compute_wRecChange!(n::passthroughNeuron, wOut::AbstractVector, modelError::Vector{Float64})
|
||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
|
|
||||||
function compute_wRecChange!(n::lifNeuron, error::Float64)
|
function compute_wRecChange!(n::lifNeuron, wOut::AbstractVector, modelError::Float64)
|
||||||
|
# how much error of this neuron 1-spike causing each output neuron's error
|
||||||
|
nError = sum(wOut * modelError)
|
||||||
|
|
||||||
n.eRec = n.phi * n.epsilonRec
|
n.eRec = n.phi * n.epsilonRec
|
||||||
ΔwRecChange = n.eta * error * n.eRec
|
ΔwRecChange = -n.eta * nError * n.eRec
|
||||||
|
# if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all
|
||||||
|
# ΔwRecChange .+= (0.2*(abs(sum(n.wRec)) / length(n.wRec)))
|
||||||
|
# end
|
||||||
n.wRecChange .+= ΔwRecChange
|
n.wRecChange .+= ΔwRecChange
|
||||||
reset_epsilonRec!(n)
|
# reset_epsilonRec!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
function compute_wRecChange!(n::alifNeuron, error::Float64)
|
function compute_wRecChange!(n::alifNeuron, wOut::AbstractVector, modelError::Float64)
|
||||||
|
# how much error of this neuron 1-spike causing each output neuron's error
|
||||||
|
# (prejected throug wOut)
|
||||||
|
nError = sum(wOut * modelError)
|
||||||
|
|
||||||
n.eRec_v = n.phi * n.epsilonRec
|
n.eRec_v = n.phi * n.epsilonRec
|
||||||
n.eRec_a = n.phi * n.beta * n.epsilonRecA
|
n.eRec_a = n.phi * n.beta * n.epsilonRecA
|
||||||
n.eRec = n.eRec_v + n.eRec_a
|
n.eRec = n.eRec_v + n.eRec_a
|
||||||
ΔwRecChange = n.eta * error * n.eRec
|
ΔwRecChange = -n.eta * nError * n.eRec
|
||||||
|
# if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all
|
||||||
|
# ΔwRecChange .+= (0.2*(abs(sum(n.wRec)) / length(n.wRec)))
|
||||||
|
# end
|
||||||
n.wRecChange .+= ΔwRecChange
|
n.wRecChange .+= ΔwRecChange
|
||||||
reset_epsilonRec!(n)
|
|
||||||
reset_epsilonRecA!(n)
|
# reset_epsilonRec!(n)
|
||||||
|
# reset_epsilonRecA!(n)
|
||||||
|
# n.alphaChange += compute_alphaChange(n.eta, nError)
|
||||||
end
|
end
|
||||||
|
|
||||||
function compute_wRecChange!(n::linearNeuron, error::Float64)
|
function compute_wRecChange!(n::integrateNeuron, error::Float64)
|
||||||
n.eRec = n.phi * n.epsilonRec
|
ΔwRecChange = -n.eta * error * n.epsilonRec
|
||||||
ΔwRecChange = n.eta * error * n.eRec
|
ΔbChange = -n.eta * error
|
||||||
|
# if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all
|
||||||
|
# ΔwRecChange .+= (abs(sum(n.wRec)) / length(n.wRec))
|
||||||
|
# end
|
||||||
n.wRecChange .+= ΔwRecChange
|
n.wRecChange .+= ΔwRecChange
|
||||||
reset_epsilonRec!(n)
|
n.bChange += ΔbChange
|
||||||
|
# reset_epsilonRec!(n)
|
||||||
|
# n.alphaChange += compute_alphaChange(n.eta, error)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# function compute_wRecChange!(n::linearNeuron, error::Float64)
|
||||||
|
# n.eRec = n.phi * n.epsilonRec
|
||||||
|
# ΔwRecChange = -n.eta * error * n.eRec
|
||||||
|
# # if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all
|
||||||
|
# # ΔwRecChange .+= (abs(sum(n.wRec)) / length(n.wRec))
|
||||||
|
# # end
|
||||||
|
# n.wRecChange .+= ΔwRecChange
|
||||||
|
# # reset_epsilonRec!(n)
|
||||||
|
# end
|
||||||
|
|
||||||
|
# add compute_alphaChange
|
||||||
|
compute_alphaChange(learningRate::Float64, total_wRecChange) = -learningRate * total_wRecChange
|
||||||
|
|
||||||
function learn!(m::model)
|
function learn!(m::model)
|
||||||
learn!(m.knowledgeFn[:I])
|
learn!(m.knowledgeFn[:I])
|
||||||
end
|
end
|
||||||
@@ -105,8 +128,8 @@ function learn!(kfn::kfn_1)
|
|||||||
# compute kfn error for each neuron
|
# compute kfn error for each neuron
|
||||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||||
# for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||||
end
|
end
|
||||||
for n in kfn.outputNeuronsArray
|
for n in kfn.outputNeuronsArray
|
||||||
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
||||||
end
|
end
|
||||||
@@ -124,35 +147,70 @@ end
|
|||||||
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
# n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
# n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
|
||||||
reset_wRecChange!(n)
|
wRecChange_reduceCoeff = 1.0
|
||||||
|
# wRecChange_max = 0.2 * abs(sum(n.wRec)) # max change 20%
|
||||||
|
# y = abs(sum(n.wRecChange))
|
||||||
|
# if y > wRecChange_max # capping weight update
|
||||||
|
# wRecChange_reduceCoeff = wRecChange_max / y
|
||||||
|
# end
|
||||||
|
n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||||
|
# n.alpha += n.alphaChange
|
||||||
|
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||||
# normalize wRec peak to prevent input signal overwhelming neuron
|
# normalize wRec peak to prevent input signal overwhelming neuron
|
||||||
normalizePeak!(n.wRec, n.wRecChange, 2)
|
# if sum(n.wRecChange) != 0
|
||||||
|
# normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||||
|
# end
|
||||||
# set weight that fliped sign to 0 for random new connection
|
# set weight that fliped sign to 0 for random new connection
|
||||||
n.wRec .*= nonFlipedSign
|
n.wRec .*= nonFlipedSign
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
# capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
synapticConnStrength!(n, "updown")
|
synapticConnStrength!(n, "updown")
|
||||||
neuroplasticity!(n, firedNeurons, nExInType)
|
neuroplasticity!(n, firedNeurons, nExInType)
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
|
function learn!(n::integrateNeuron, firedNeurons, nExInType, totalInputPort)
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wRecChange_reduceCoeff = 1.0
|
||||||
# n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
# wRecChange_max = 0.2 * abs(sum(n.wRec)) # max change 20%
|
||||||
n.wRec += n.wRecChange
|
# y = abs(sum(n.wRecChange))
|
||||||
reset_wRecChange!(n)
|
# if y > wRecChange_max # capping weight update
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
# wRecChange_reduceCoeff = wRecChange_max / y
|
||||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
# end
|
||||||
# normalize wRec peak to prevent input signal overwhelming neuron
|
n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||||
normalizePeak!(n.wRec, n.wRecChange, 2)
|
n.b += (wRecChange_reduceCoeff * n.bChange)
|
||||||
# set weight that fliped sign to 0 for random new connection
|
# n.alpha += n.alphaChange
|
||||||
n.wRec .*= nonFlipedSign
|
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
|
||||||
synapticConnStrength!(n, "updown")
|
|
||||||
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# function learn!(n::linearNeuron, firedNeurons, nExInType, totalInputPort)
|
||||||
|
# wSign_0 = sign.(n.wRec) # original sign
|
||||||
|
# # n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||||
|
# wRecChange_max = 0.1 * abs(sum(n.wRec)) # max change 20%
|
||||||
|
# y = abs(sum(n.wRecChange))
|
||||||
|
# wRecChange_reduceCoeff = 1.0
|
||||||
|
# # if y > wRecChange_max # capping weight update
|
||||||
|
# # wRecChange_reduceCoeff = wRecChange_max / y
|
||||||
|
# # end
|
||||||
|
# n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||||
|
# n.alpha += n.alphaChange
|
||||||
|
|
||||||
|
# wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
|
# nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||||
|
# # normalize wRec peak to prevent input signal overwhelming neuron
|
||||||
|
# if sum(n.wRecChange) != 0
|
||||||
|
# # normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||||
|
# end
|
||||||
|
# # set weight that fliped sign to 0 for random new connection
|
||||||
|
# # n.wRec .*= nonFlipedSign
|
||||||
|
# # capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
# # synapticConnStrength!(n, "updown")
|
||||||
|
# # neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
217
src/snn_utils.jl
217
src/snn_utils.jl
@@ -13,6 +13,7 @@ using GeneralUtils
|
|||||||
using ..types
|
using ..types
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
rng = MersenneTwister(1234)
|
||||||
|
|
||||||
function timestep_forward!(x::passthroughNeuron)
|
function timestep_forward!(x::passthroughNeuron)
|
||||||
x.z_t = x.z_t1
|
x.z_t = x.z_t1
|
||||||
@@ -30,6 +31,7 @@ precision(x::Array{<:Array}) = ( std(mean.(x)) / mean(mean.(x)) ) * 100
|
|||||||
reset_last_firing_time!(n::computeNeuron) = n.lastFiringTime = 0.0
|
reset_last_firing_time!(n::computeNeuron) = n.lastFiringTime = 0.0
|
||||||
reset_refractory_state_active!(n::computeNeuron) = n.refractory_state_active = false
|
reset_refractory_state_active!(n::computeNeuron) = n.refractory_state_active = false
|
||||||
reset_v_t!(n::neuron) = n.v_t = n.vRest
|
reset_v_t!(n::neuron) = n.v_t = n.vRest
|
||||||
|
reset_v_t1!(n::neuron) = n.v_t1 = n.vRest
|
||||||
reset_z_t!(n::computeNeuron) = n.z_t = false
|
reset_z_t!(n::computeNeuron) = n.z_t = false
|
||||||
reset_epsilonRec!(n::computeNeuron) = n.epsilonRec = n.epsilonRec * 0.0
|
reset_epsilonRec!(n::computeNeuron) = n.epsilonRec = n.epsilonRec * 0.0
|
||||||
reset_epsilonRec!(n::outputNeuron) = n.epsilonRec = n.epsilonRec * 0.0
|
reset_epsilonRec!(n::outputNeuron) = n.epsilonRec = n.epsilonRec * 0.0
|
||||||
@@ -46,66 +48,23 @@ reset_firing_counter!(n::Union{computeNeuron, outputNeuron}) = n.firingCounter =
|
|||||||
reset_firing_diff!(n::Union{computeNeuron, outputNeuron}) = n.firingDiff = n.firingDiff * 0.0
|
reset_firing_diff!(n::Union{computeNeuron, outputNeuron}) = n.firingDiff = n.firingDiff * 0.0
|
||||||
reset_refractoryCounter!(n::Union{computeNeuron, outputNeuron}) = n.refractoryCounter = n.refractoryCounter * 0.0
|
reset_refractoryCounter!(n::Union{computeNeuron, outputNeuron}) = n.refractoryCounter = n.refractoryCounter * 0.0
|
||||||
reset_z_i_t_commulative!(n::Union{computeNeuron, outputNeuron}) = n.z_i_t_commulative = n.z_i_t_commulative * 0.0
|
reset_z_i_t_commulative!(n::Union{computeNeuron, outputNeuron}) = n.z_i_t_commulative = n.z_i_t_commulative * 0.0
|
||||||
|
reset_alphaChange!(n::Union{computeNeuron, outputNeuron}) = n.alphaChange = n.alphaChange * 0.0
|
||||||
|
|
||||||
# reset function for output neuron
|
# reset function for output neuron
|
||||||
reset_epsilon_j!(n::linearNeuron) = n.epsilon_j = n.epsilon_j * 0.0
|
reset_epsilon_j!(n::linearNeuron) = n.epsilon_j = n.epsilon_j * 0.0
|
||||||
reset_out_t!(n::linearNeuron) = n.out_t = n.out_t * 0.0
|
reset_out_t!(n::linearNeuron) = n.out_t = n.out_t * 0.0
|
||||||
reset_w_out_change!(n::linearNeuron) = n.w_out_change = n.w_out_change * 0.0
|
reset_bChange!(n::integrateNeuron) = n.bChange = n.bChange * 0.0
|
||||||
reset_b_change!(n::linearNeuron) = n.b_change = n.b_change * 0.0
|
|
||||||
|
|
||||||
|
|
||||||
""" Reset a part of learning-related params that used to collect learning history during learning
|
|
||||||
session
|
|
||||||
"""
|
|
||||||
# function reset_learning_no_wchange!(n::lifNeuron)
|
|
||||||
# reset_epsilonRec!(n)
|
|
||||||
# # reset_v_t!(n)
|
|
||||||
# # reset_z_t!(n)
|
|
||||||
# # reset_reg_voltage_a!(n)
|
|
||||||
# # reset_reg_voltage_b!(n)
|
|
||||||
# # reset_reg_voltage_error!(n)
|
|
||||||
# reset_firing_counter!(n)
|
|
||||||
# reset_firing_diff!(n)
|
|
||||||
# reset_previous_error!(n)
|
|
||||||
# reset_error!(n)
|
|
||||||
|
|
||||||
# # # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
|
|
||||||
# # # it will stay in refractory state forever
|
|
||||||
# # reset_refractory_state_active!(n)
|
|
||||||
# end
|
|
||||||
# function reset_learning_no_wchange!(n::Union{alifNeuron, elif_neuron})
|
|
||||||
# reset_epsilonRec!(n)
|
|
||||||
# reset_epsilonRecA!(n)
|
|
||||||
# reset_v_t!(n)
|
|
||||||
# reset_z_t!(n)
|
|
||||||
# # reset_a!(n)
|
|
||||||
# reset_reg_voltage_a!(n)
|
|
||||||
# reset_reg_voltage_b!(n)
|
|
||||||
# reset_reg_voltage_error!(n)
|
|
||||||
# reset_firing_counter!(n)
|
|
||||||
# reset_firing_diff!(n)
|
|
||||||
# reset_previous_error!(n)
|
|
||||||
# reset_error!(n)
|
|
||||||
|
|
||||||
# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
|
|
||||||
# # it will stay in refractory state forever
|
|
||||||
# reset_refractory_state_active!(n)
|
|
||||||
# end
|
|
||||||
# function reset_learning_no_wchange!(n::linearNeuron)
|
|
||||||
# reset_epsilon_j!(n)
|
|
||||||
# reset_out_t!(n)
|
|
||||||
# reset_error!(n)
|
|
||||||
# end
|
|
||||||
|
|
||||||
""" Reset all learning-related params at the END of learning session
|
""" Reset all learning-related params at the END of learning session
|
||||||
"""
|
"""
|
||||||
function resetLearningParams!(n::lifNeuron)
|
function resetLearningParams!(n::lifNeuron)
|
||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
# reset_v_t!(n)
|
reset_v_t!(n)
|
||||||
# reset_z_t!(n)
|
reset_z_t!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
reset_firing_diff!(n)
|
reset_firing_diff!(n)
|
||||||
|
reset_alphaChange!(n)
|
||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
@@ -116,11 +75,12 @@ function resetLearningParams!(n::alifNeuron)
|
|||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
reset_epsilonRecA!(n)
|
reset_epsilonRecA!(n)
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
# reset_v_t!(n)
|
reset_v_t!(n)
|
||||||
# reset_z_t!(n)
|
reset_z_t!(n)
|
||||||
# reset_a!(n)
|
reset_a!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
reset_firing_diff!(n)
|
reset_firing_diff!(n)
|
||||||
|
reset_alphaChange!(n)
|
||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
@@ -135,21 +95,31 @@ function resetLearningParams!(n::passthroughNeuron)
|
|||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
|
|
||||||
function resetLearningParams!(n::linearNeuron)
|
# function resetLearningParams!(n::linearNeuron)
|
||||||
|
# reset_epsilonRec!(n)
|
||||||
|
# reset_wRecChange!(n)
|
||||||
|
# # reset_v_t!(n)
|
||||||
|
# reset_firing_counter!(n)
|
||||||
|
|
||||||
|
# # reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
|
# # refractory state, it will stay in refractory state forever
|
||||||
|
# # reset_refractoryCounter!(n)
|
||||||
|
# reset_z_i_t_commulative!(n)
|
||||||
|
# end
|
||||||
|
|
||||||
|
function resetLearningParams!(n::integrateNeuron)
|
||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
# reset_v_t!(n)
|
reset_bChange!(n)
|
||||||
|
reset_v_t!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
|
reset_alphaChange!(n)
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
|
||||||
# refractory state, it will stay in refractory state forever
|
|
||||||
# reset_refractoryCounter!(n)
|
|
||||||
reset_z_i_t_commulative!(n)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
function store_knowledgefn_error!(kfn::knowledgeFn)
|
function store_knowledgefn_error!(kfn::knowledgeFn)
|
||||||
# condition to adjust nueron in KFN plane in addition to weight adjustment inside each neuron
|
# condition to adjust neuron in KFN plane in addition to weight adjustment inside each neuron
|
||||||
if kfn.learningStage == "start_learning"
|
if kfn.learningStage == "start_learning"
|
||||||
if kfn.recent_knowledgeFn_error === nothing && kfn.knowledgeFn_error === nothing
|
if kfn.recent_knowledgeFn_error === nothing && kfn.knowledgeFn_error === nothing
|
||||||
kfn.recent_knowledgeFn_error = [[]]
|
kfn.recent_knowledgeFn_error = [[]]
|
||||||
@@ -279,13 +249,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
|
|||||||
|
|
||||||
if updown == "up"
|
if updown == "up"
|
||||||
if currentStrength > 4 # strong connection
|
if currentStrength > 4 # strong connection
|
||||||
updatedStrength = currentStrength + (Δstrength * 1.0)
|
updatedStrength = currentStrength + (Δstrength * 0.2)
|
||||||
else
|
else
|
||||||
updatedStrength = currentStrength + (Δstrength * 0.1)
|
updatedStrength = currentStrength + (Δstrength * 0.1)
|
||||||
end
|
end
|
||||||
elseif updown == "down"
|
elseif updown == "down"
|
||||||
if currentStrength > 4
|
if currentStrength > 4
|
||||||
updatedStrength = currentStrength - (Δstrength * 0.5)
|
updatedStrength = currentStrength - (Δstrength * 0.1)
|
||||||
else
|
else
|
||||||
updatedStrength = currentStrength - (Δstrength * 0.2)
|
updatedStrength = currentStrength - (Δstrength * 0.2)
|
||||||
end
|
end
|
||||||
@@ -358,81 +328,80 @@ function neuroplasticity!(n::computeNeuron, firedNeurons::Vector,
|
|||||||
nExInTypeList::Vector)
|
nExInTypeList::Vector)
|
||||||
# if there is 0-weight then replace it with new connection
|
# if there is 0-weight then replace it with new connection
|
||||||
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
||||||
|
if length(zeroWeightConnIndex) != 0
|
||||||
|
# new synaptic connection must sample fron neuron that fires
|
||||||
|
nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
||||||
|
filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
||||||
|
|
||||||
# new synaptic connection must sample fron neuron that fires
|
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
||||||
nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
|
||||||
filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
|
||||||
|
|
||||||
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
||||||
filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
||||||
filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
|
||||||
|
|
||||||
w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex))
|
w = randn(length(zeroWeightConnIndex)) / 100
|
||||||
synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex))
|
synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex))
|
||||||
|
|
||||||
shuffle!(nFiredPool)
|
shuffle!(nFiredPool)
|
||||||
shuffle!(nNonFiredPool)
|
shuffle!(nNonFiredPool)
|
||||||
|
|
||||||
# add new synaptic connection to neuron
|
# add new synaptic connection to neuron
|
||||||
for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
||||||
if length(nFiredPool) != 0
|
""" conn that is being replaced has to go into nNonFiredPool so
|
||||||
newConn = popfirst!(nFiredPool)
|
nNonFiredPool isn't empty """
|
||||||
else
|
|
||||||
newConn = popfirst!(nNonFiredPool)
|
|
||||||
end
|
|
||||||
|
|
||||||
""" conn that is being replaced has to go into nNonFiredPool so nNonFiredPool isn't empty
|
|
||||||
"""
|
|
||||||
push!(nNonFiredPool, n.subscriptionList[connIndex])
|
|
||||||
n.subscriptionList[connIndex] = newConn
|
|
||||||
n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
|
|
||||||
n.synapticStrength[connIndex] = synapticStrength[i]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function neuroplasticity!(n::outputNeuron, firedNeurons::Vector,
|
|
||||||
nExInTypeList::Vector, totalInputNeuron::Integer)
|
|
||||||
# if there is 0-weight then replace it with new connection
|
|
||||||
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
|
||||||
|
|
||||||
# new synaptic connection must sample fron neuron that fires
|
|
||||||
nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
|
||||||
filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
|
||||||
filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron
|
|
||||||
|
|
||||||
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
|
||||||
filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
|
||||||
filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
|
||||||
filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron
|
|
||||||
|
|
||||||
w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex))
|
|
||||||
synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex))
|
|
||||||
|
|
||||||
shuffle!(nFiredPool)
|
|
||||||
shuffle!(nNonFiredPool)
|
|
||||||
|
|
||||||
# add new synaptic connection to neuron
|
|
||||||
for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
|
||||||
newConn::Int64 = 0
|
|
||||||
if length(nFiredPool) != 0
|
|
||||||
newConn = popfirst!(nFiredPool)
|
|
||||||
elseif length(nNonFiredPool) != 0
|
|
||||||
newConn = popfirst!(nNonFiredPool)
|
|
||||||
else
|
|
||||||
# skip
|
|
||||||
end
|
|
||||||
|
|
||||||
if newConn != 0
|
|
||||||
""" conn that is being replaced has to go into nNonFiredPool so nNonFiredPool isn't empty
|
|
||||||
"""
|
|
||||||
push!(nNonFiredPool, n.subscriptionList[connIndex])
|
push!(nNonFiredPool, n.subscriptionList[connIndex])
|
||||||
|
|
||||||
|
if length(nFiredPool) != 0
|
||||||
|
newConn = popfirst!(nFiredPool)
|
||||||
|
else
|
||||||
|
newConn = popfirst!(nNonFiredPool)
|
||||||
|
end
|
||||||
n.subscriptionList[connIndex] = newConn
|
n.subscriptionList[connIndex] = newConn
|
||||||
n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
|
n.wRec[connIndex] = abs(w[i]) * nExInTypeList[newConn]
|
||||||
n.synapticStrength[connIndex] = synapticStrength[i]
|
n.synapticStrength[connIndex] = synapticStrength[i]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# function neuroplasticity!(n::outputNeuron, firedNeurons::Vector,
|
||||||
|
# nExInTypeList::Vector, totalInputNeuron::Integer)
|
||||||
|
# # if there is 0-weight then replace it with new connection
|
||||||
|
# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
||||||
|
# if length(zeroWeightConnIndex) != 0
|
||||||
|
# # new synaptic connection must sample fron neuron that fires
|
||||||
|
# nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
||||||
|
# filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
||||||
|
# filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron
|
||||||
|
|
||||||
|
# nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
||||||
|
# unique!(append!(nNonFiredPool, zeroWeightConnIndex))
|
||||||
|
# filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
||||||
|
# filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
||||||
|
# filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron
|
||||||
|
|
||||||
|
# w = randn(length(zeroWeightConnIndex)) / 100
|
||||||
|
# synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex))
|
||||||
|
|
||||||
|
# shuffle!(nFiredPool)
|
||||||
|
# shuffle!(nNonFiredPool)
|
||||||
|
|
||||||
|
# # add new synaptic connection to neuron
|
||||||
|
# for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
||||||
|
# """ conn that is being replaced has to go into nNonFiredPool so
|
||||||
|
# nNonFiredPool isn't empty """
|
||||||
|
# push!(nNonFiredPool, n.subscriptionList[connIndex])
|
||||||
|
|
||||||
|
# if length(nFiredPool) != 0
|
||||||
|
# newConn = popfirst!(nFiredPool)
|
||||||
|
# else
|
||||||
|
# newConn = popfirst!(nNonFiredPool)
|
||||||
|
# end
|
||||||
|
# n.subscriptionList[connIndex] = newConn
|
||||||
|
# n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
|
||||||
|
# n.synapticStrength[connIndex] = synapticStrength[i]
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
""" Cap maximum weight of each neuron connection
|
""" Cap maximum weight of each neuron connection
|
||||||
"""
|
"""
|
||||||
function capMaxWeight!(v::Vector{Float64}, max=1.0)
|
function capMaxWeight!(v::Vector{Float64}, max=1.0)
|
||||||
|
|||||||
153
src/types.jl
153
src/types.jl
@@ -4,6 +4,7 @@ export
|
|||||||
# struct
|
# struct
|
||||||
IronpenStruct, model, knowledgeFn, lifNeuron, alifNeuron, linearNeuron,
|
IronpenStruct, model, knowledgeFn, lifNeuron, alifNeuron, linearNeuron,
|
||||||
kfn_1, inputNeuron, computeNeuron, neuron, outputNeuron, passthroughNeuron,
|
kfn_1, inputNeuron, computeNeuron, neuron, outputNeuron, passthroughNeuron,
|
||||||
|
integrateNeuron,
|
||||||
|
|
||||||
# function
|
# function
|
||||||
instantiate_custom_types, init_neuron, populate_neuron,
|
instantiate_custom_types, init_neuron, populate_neuron,
|
||||||
@@ -22,6 +23,8 @@ abstract type computeNeuron <: neuron end
|
|||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
|
rng = MersenneTwister(1234)
|
||||||
|
|
||||||
""" Model struct
|
""" Model struct
|
||||||
"""
|
"""
|
||||||
Base.@kwdef mutable struct model <: Ironpen
|
Base.@kwdef mutable struct model <: Ironpen
|
||||||
@@ -262,16 +265,16 @@ function kfn_1(kfnParams::Dict)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# add ExInType into each output neuron subExInType
|
# # add ExInType into each output neuron subExInType
|
||||||
for n in kfn.outputNeuronsArray
|
# for n in kfn.outputNeuronsArray
|
||||||
try # input neuron doest have n.subscriptionList
|
# try # input neuron doest have n.subscriptionList
|
||||||
for (i, sub_id) in enumerate(n.subscriptionList)
|
# for (i, sub_id) in enumerate(n.subscriptionList)
|
||||||
n_ExInType = kfn.neuronsArray[sub_id].ExInType
|
# n_ExInType = kfn.neuronsArray[sub_id].ExInType
|
||||||
n.wRec[i] *= n_ExInType
|
# n.wRec[i] *= n_ExInType
|
||||||
end
|
# end
|
||||||
catch
|
# catch
|
||||||
end
|
# end
|
||||||
end
|
# end
|
||||||
|
|
||||||
for n in kfn.neuronsArray
|
for n in kfn.neuronsArray
|
||||||
push!(kfn.nExInType, n.ExInType)
|
push!(kfn.nExInType, n.ExInType)
|
||||||
@@ -339,6 +342,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
|
|||||||
|
|
||||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||||
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||||
|
alphaChange::Float64 = 0.0
|
||||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||||
epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike
|
epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike
|
||||||
decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec
|
decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec
|
||||||
@@ -347,7 +351,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
|
|||||||
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||||
refractoryCounter::Int64 = 0
|
refractoryCounter::Int64 = 0
|
||||||
tau_m::Float64 = 100.0 # τ_m, membrane time constant in millisecond
|
tau_m::Float64 = 100.0 # τ_m, membrane time constant in millisecond
|
||||||
eta::Float64 = 0.01 # η, learning rate
|
eta::Float64 = 1e-3 # η, learning rate
|
||||||
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||||
@@ -428,6 +432,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
|||||||
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
||||||
|
|
||||||
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||||
|
alphaChange::Float64 = 0.0
|
||||||
delta::Float64 = 1.0 # δ, discreate timestep size in millisecond
|
delta::Float64 = 1.0 # δ, discreate timestep size in millisecond
|
||||||
epsilonRec::Array{Float64} = Float64[] # ϵ_rec(v), eligibility vector for neuron i spike
|
epsilonRec::Array{Float64} = Float64[] # ϵ_rec(v), eligibility vector for neuron i spike
|
||||||
epsilonRecA::Array{Float64} = Float64[] # ϵ_rec(a)
|
epsilonRecA::Array{Float64} = Float64[] # ϵ_rec(a)
|
||||||
@@ -435,7 +440,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
|||||||
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
|
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
|
||||||
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
|
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
|
||||||
eRec::Array{Float64} = Float64[] # neuron's eligibility trace
|
eRec::Array{Float64} = Float64[] # neuron's eligibility trace
|
||||||
eta::Float64 = 0.01 # eta, learning rate
|
eta::Float64 = 1e-3 # eta, learning rate
|
||||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||||
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
|
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
|
||||||
@@ -510,7 +515,7 @@ end
|
|||||||
""" linearNeuron struct
|
""" linearNeuron struct
|
||||||
"""
|
"""
|
||||||
Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
||||||
id::Float64 = 0.0 # ID of this neuron which is it position in knowledgeFn array
|
id::Int64 = 0 # ID of this neuron which is it position in knowledgeFn array
|
||||||
type::String = "linearNeuron"
|
type::String = "linearNeuron"
|
||||||
knowledgeFnName::String = "not defined" # knowledgeFn that this neuron belongs to
|
knowledgeFnName::String = "not defined" # knowledgeFn that this neuron belongs to
|
||||||
subscriptionList::Array{Int64} = Int64[] # list of other neuron that this neuron synapse subscribed to
|
subscriptionList::Array{Int64} = Int64[] # list of other neuron that this neuron synapse subscribed to
|
||||||
@@ -545,7 +550,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
|||||||
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||||
refractoryCounter::Int64 = 0
|
refractoryCounter::Int64 = 0
|
||||||
tau_out::Float64 = 50.0 # τ_out, membrane time constant in millisecond
|
tau_out::Float64 = 50.0 # τ_out, membrane time constant in millisecond
|
||||||
eta::Float64 = 0.01 # η, learning rate
|
eta::Float64 = 1e-3 # η, learning rate
|
||||||
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||||
@@ -584,6 +589,87 @@ function linearNeuron(params::Dict)
|
|||||||
return n
|
return n
|
||||||
end
|
end
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
""" integrateNeuron struct
|
||||||
|
"""
|
||||||
|
Base.@kwdef mutable struct integrateNeuron <: outputNeuron
|
||||||
|
id::Int64 = 0 # ID of this neuron which is it position in knowledgeFn array
|
||||||
|
type::String = "integrateNeuron"
|
||||||
|
knowledgeFnName::String = "not defined" # knowledgeFn that this neuron belongs to
|
||||||
|
subscriptionList::Array{Int64} = Int64[] # list of other neuron that this neuron synapse subscribed to
|
||||||
|
timeStep::Int64 = 0 # current time
|
||||||
|
wRec::Array{Float64} = Float64[] # synaptic weight (for receiving signal from other neuron)
|
||||||
|
v_t::Float64 = randn() # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||||
|
v_t1::Float64 = 0.0 # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||||
|
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
||||||
|
vRest::Float64 = 0.0 # resting potential after neuron fired
|
||||||
|
vError::Float64 = 0.0 # used to compute model error
|
||||||
|
z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep
|
||||||
|
# zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all
|
||||||
|
# neurons forward function at each timestep-by-timestep is to do every neuron
|
||||||
|
# forward calculation. Each neuron requires access to other neuron's firing status
|
||||||
|
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
||||||
|
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
||||||
|
b::Float64 = 0.0
|
||||||
|
bChange::Float64 = 0.0
|
||||||
|
|
||||||
|
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
|
||||||
|
# previous timestep)
|
||||||
|
z_i_t::Array{Bool} = Bool[]
|
||||||
|
z_i_t_commulative::Array{Int64} = Int64[] # used to compute connection strength
|
||||||
|
synapticStrength::Array{Float64} = Float64[]
|
||||||
|
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>-5), upperlimit=(5=>5))
|
||||||
|
|
||||||
|
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||||
|
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||||
|
alphaChange::Float64 = 0.0
|
||||||
|
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||||
|
epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike
|
||||||
|
decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec
|
||||||
|
eRec::Array{Float64} = Float64[] # eligibility trace for neuron spike
|
||||||
|
delta::Float64 = 1.0 # δ, discreate timestep size in millisecond
|
||||||
|
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||||
|
refractoryCounter::Int64 = 0
|
||||||
|
tau_out::Float64 = 50.0 # τ_out, membrane time constant in millisecond
|
||||||
|
eta::Float64 = 1e-3 # η, learning rate
|
||||||
|
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||||
|
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||||
|
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||||
|
|
||||||
|
firingCounter::Int64 = 0 # store how many times neuron fires
|
||||||
|
ExInSignalSum::Float64 = 0.0
|
||||||
|
end
|
||||||
|
|
||||||
|
""" linear neuron outer constructor
|
||||||
|
|
||||||
|
# Example
|
||||||
|
|
||||||
|
linear_neuron_params = Dict(
|
||||||
|
:type => "linearNeuron",
|
||||||
|
:k => 0.9, # output leakink coefficient
|
||||||
|
:tau_out => 5.0, # output time constant in millisecond. It should equals to time use for 1 sequence
|
||||||
|
:out => 0.0, # neuron's output value store here
|
||||||
|
)
|
||||||
|
|
||||||
|
neuron1 = linearNeuron(linear_neuron_params)
|
||||||
|
"""
|
||||||
|
function integrateNeuron(params::Dict)
|
||||||
|
n = integrateNeuron()
|
||||||
|
field_names = fieldnames(typeof(n))
|
||||||
|
for i in field_names
|
||||||
|
if i in keys(params)
|
||||||
|
if i == :optimiser
|
||||||
|
opt_type = string(split(params[i], ".")[end])
|
||||||
|
n.:($i) = load_optimiser(opt_type)
|
||||||
|
else
|
||||||
|
n.:($i) = params[i] # assign params to n struct fields
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return n
|
||||||
|
end
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
# function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
|
# function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
|
||||||
@@ -634,10 +720,11 @@ function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict)
|
|||||||
|
|
||||||
# prevent subscription to itself by removing this neuron id
|
# prevent subscription to itself by removing this neuron id
|
||||||
filter!(x -> x != n.id, n.subscriptionList)
|
filter!(x -> x != n.id, n.subscriptionList)
|
||||||
n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList))
|
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.wRec = randn(length(n.subscriptionList))
|
# n.wRec = randn(length(n.subscriptionList))
|
||||||
|
n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
n.alpha = calculate_α(n)
|
n.alpha = calculate_α(n)
|
||||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||||
@@ -654,10 +741,10 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict,
|
|||||||
|
|
||||||
# prevent subscription to itself by removing this neuron id
|
# prevent subscription to itself by removing this neuron id
|
||||||
filter!(x -> x != n.id, n.subscriptionList)
|
filter!(x -> x != n.id, n.subscriptionList)
|
||||||
n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList))
|
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.wRec = randn(length(n.subscriptionList))
|
n.wRec = randn(rng, length(n.subscriptionList)) / 100 # TODO use abs()
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
|
|
||||||
# the more time has passed from the last time neuron was activated, the more
|
# the more time has passed from the last time neuron was activated, the more
|
||||||
@@ -668,8 +755,7 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict,
|
|||||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function init_neuron!(id::Int64, n::integrateNeuron, n_params::Dict, kfnParams::Dict)
|
||||||
function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dict)
|
|
||||||
n.id = id
|
n.id = id
|
||||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||||
|
|
||||||
@@ -677,15 +763,33 @@ function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dic
|
|||||||
subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) *
|
subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) *
|
||||||
kfnParams[:totalNeurons] - kfnParams[:totalInputPort]))
|
kfnParams[:totalNeurons] - kfnParams[:totalInputPort]))
|
||||||
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||||
n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList))
|
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.wRec = randn(length(n.subscriptionList))
|
n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
n.alpha = calculate_k(n)
|
n.alpha = calculate_k(n)
|
||||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||||
|
n.b = randn(rng) / 100
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dict)
|
||||||
|
# n.id = id
|
||||||
|
# n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||||
|
|
||||||
|
# subscription_options = shuffle!([kfnParams[:totalInputPort]+1 : kfnParams[:totalNeurons]...])
|
||||||
|
# subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) *
|
||||||
|
# kfnParams[:totalNeurons] - kfnParams[:totalInputPort]))
|
||||||
|
# n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||||
|
# n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||||
|
|
||||||
|
# n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
|
# n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||||
|
# n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
|
# n.alpha = calculate_k(n)
|
||||||
|
# n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||||
|
# end
|
||||||
|
|
||||||
""" Make a neuron intended for use with knowledgeFn
|
""" Make a neuron intended for use with knowledgeFn
|
||||||
"""
|
"""
|
||||||
function init_neuron(id::Int64, n_params::Dict, kfnParams::Dict)
|
function init_neuron(id::Int64, n_params::Dict, kfnParams::Dict)
|
||||||
@@ -715,7 +819,9 @@ function instantiate_custom_types(params::Union{Dict,Nothing} = nothing)
|
|||||||
elseif type == "alifNeuron"
|
elseif type == "alifNeuron"
|
||||||
return alifNeuron(params)
|
return alifNeuron(params)
|
||||||
elseif type == "linearNeuron"
|
elseif type == "linearNeuron"
|
||||||
return linearNeuron(params)
|
return linearNeuron(params)
|
||||||
|
elseif type == "integrateNeuron"
|
||||||
|
return integrateNeuron(params)
|
||||||
else
|
else
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
@@ -740,6 +846,7 @@ calculate_α(neuron::lifNeuron) = exp(-neuron.delta / neuron.tau_m)
|
|||||||
calculate_α(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_m)
|
calculate_α(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_m)
|
||||||
calculate_ρ(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_a)
|
calculate_ρ(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_a)
|
||||||
calculate_k(neuron::linearNeuron) = exp(-neuron.delta / neuron.tau_out)
|
calculate_k(neuron::linearNeuron) = exp(-neuron.delta / neuron.tau_out)
|
||||||
|
calculate_k(neuron::integrateNeuron) = exp(-neuron.delta / neuron.tau_out)
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user