diff --git a/example_main.jl b/example_main.jl new file mode 100644 index 0000000..12e4924 --- /dev/null +++ b/example_main.jl @@ -0,0 +1,826 @@ +using Pkg; Pkg.activate("."); Pkg.resolve(), Pkg.instantiate() +using Revise +using Flux #, CUDA +using BSON, JSON3 +using MLDatasets: MNIST +using MLUtils, Images, ProgressMeter, Dates, DataFrames, Random, Statistics, LinearAlgebra, + BenchmarkTools, Serialization, OneHotArrays , GLMakie # ClickHouse + + +# if one need to reinstall all python packages +# try Pkg.rm("PythonCall") catch end # should be removed before using CondaPkg to install packages +# condapackage = ["numpy", "pytorch", "snntorch"] +# using CondaPkg # in CondaPkg.toml file, channels = ["anaconda", "conda-forge", "pytorch"] +# for i in condapackage +# try CondaPkg.rm(i) catch end +# end +# for i in condapackage +# CondaPkg.add(i) +# end +# Pkg.add("PythonCall"); + +using PythonCall; +np = pyimport("numpy") +torch = pyimport("torch") +spikegen = pyimport("snntorch.spikegen") # https://github.com/jeshraghian/snntorch + +using Ironpen +using GeneralUtils + +sep = Sys.iswindows() ? "\\" : "/" +rootDir = pwd() + +# select compute device +# device = Flux.CUDA.functional() ? gpu : cpu +# if device == gpu +# CUDA.device!(3) +# end +#------------------------------------------------------------------------------------------------100 + + + +""" + Todo: + - [] + + Change from version: + - + + All features + - +""" + + +# communication config --------------------------------------------------------------------------100 + +database_ip = "localhost" +# database_ip = "192.168.0.8" + +#------------------------------------------------------------------------------------------------100 + +function generate_snn(filename::String, location::String) + expect_compute_neuron_numbers = 1024 #FIXME change to 512 + signalInput_portnumbers = 50 + noise_portnumbers = signalInput_portnumbers + output_portnumbers = 10 + + lif_neuron_number = Int(floor(expect_compute_neuron_numbers * 0.4)) + alif_neuron_number = expect_compute_neuron_numbers - lif_neuron_number # from Allen Institute, ALIF is 20-40% of LIF + computeNeuronNumber = lif_neuron_number + alif_neuron_number + + totalNeurons = computeNeuronNumber + noise_portnumbers + signalInput_portnumbers + totalInputPort = noise_portnumbers + signalInput_portnumbers + + # kfn and neuron config + passthrough_neuron_params = Dict( + :type => "passthroughNeuron" + ) + + lif_neuron_params = Dict{Symbol, Any}( + :type => "lifNeuron", + :v_t_default => 0.0, + :v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate) + :tau_m => 200.0, # membrane time constant in millisecond. + :eta => 1e-2, + # Good starting value is 1/10th of tau_a + # This is problem specific parameter. It controls how leaky the neuron is. + # Too high(less leaky) makes learning algo harder to move model into direction that reduce error + # resulting in model's error to explode exponantially likely because learning algo will try to + # exert more force (larger w_out_change) to move neuron into direction that reduce error + # For example, model error from 7 to 2e6. + + :synapticConnectionPercent => 50, # % coverage of total neurons in kfn + :w_rec_generation_pattern => "random", + ) + + alif_neuron_params = Dict{Symbol, Any}( + :type => "alifNeuron", + :v_t_default => 0.0, + :v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate) + :tau_m => 200.0, # membrane time constant in millisecond. + :eta => 1e-2, + # Good starting value is 1/10th of tau_a + # This is problem specific parameter. It controls how leaky the neuron is. + # Too high(less leaky) makes learning algo harder to move model into direction that reduce error + # resulting in model's error to explode exponantially likely because learning algo will try to + # exert more force (larger w_out_change) to move neuron into direction that reduce error + # For example, model error from 7 to 2e6. + + :tau_a => 500.0, # adaptation time constant in millisecond. it defines neuron memory length. + # This is problem specific parameter + # Good starting value is 0.5 to 2 times of info STORE-RECALL length i.e. total time SNN takes to + # perform a task, for example, equals to episode length. + # From "Spike frequency adaptation supports network computations on temporally dispersed + # information" + + :synapticConnectionPercent => 50, # % coverage of total neurons in kfn + :w_rec_generation_pattern => "random", + ) + + # linear_neuron_params = Dict{Symbol, Any}( + # :type => "linearNeuron", + # :v_th => 1.0, # neuron firing threshold (this value is treated as maximum bound if I use auto generate) + # :tau_out => 50.0, # output time constant in millisecond. + # :synapticConnectionPercent => 100, # % coverage of total neurons in kfn + # # Good starting value is 1/50th of tau_a + # # This is problem specific parameter. + # # It controls how leaky the neuron is. + # # Too high(less leaky) makes learning algo harder to move model into direction that reduce error + # # resulting in model's error to explode exponantially. For example, model error from 7 to 2e6 + # # One can image training output neuron is like Tetris Game. + # ) + + integrate_neuron_params = Dict{Symbol, Any}( + :type => "integrateNeuron", + :synapticConnectionPercent => 100, # % coverage of total neurons in kfn + :eta => 1e-2, + :tau_out => 100.0, + # Good starting value is 1/50th of tau_a + # This is problem specific parameter. + # It controls how leaky the neuron is. + # Too high(less leaky) makes learning algo harder to move model into direction that reduce error + # resulting in model's error to explode exponantially. For example, model error from 7 to 2e6 + # One can image training output neuron is like Tetris Game. + ) + + I_kfnparams = Dict{Symbol, Any}( + :knowledgeFnName=> "I", + :computeNeuronNumber=> computeNeuronNumber, + :neuronFiringRateTarget=> 10.0, # Hz + :Bn=> "random", # error projection coefficient for EACH neuron + :totalNeurons=> totalNeurons, + :totalInputPort=> totalInputPort, + :totalComputeNeuron=> computeNeuronNumber, + + # group relavent info + :inputPort=> Dict( + :noise=> Dict( + :numbers=> noise_portnumbers, + :params=> passthrough_neuron_params, + ), + :signal=> Dict( + :numbers=> signalInput_portnumbers, # in case of GloVe word encoding, it is 300 + :params=> passthrough_neuron_params, + ), + ), + :outputPort=> Dict( + :numbers=> output_portnumbers, # output neuron, this is also the output length + :params=> integrate_neuron_params, + ), + :computeNeuron=> Dict( + :1=> Dict( + :numbers=> lif_neuron_number, + :params=> lif_neuron_params, + ), + :2=> Dict( + :numbers=> alif_neuron_number, + :params=> alif_neuron_params, + ), + ), + ) + + #------------------------------------------------------------------------------------------------100 + + I_kfn = Ironpen.kfn_1(I_kfnparams) + + model_params_1 = Dict(:knowledgeFn => Dict( + :I => I_kfn), + ) + + model = Ironpen.model(model_params_1) + + serialize(location * sep * filename, model) + println("SNN generated") +end + +function data_loader() + # test problem + fullTrainDataset = MNIST(:train) + prototypeDataset = fullTrainDataset[1:10] # use reshape(test_dataset[1], (:, 1)) to flaten matrix + trainDataset = fullTrainDataset # total 60000 + validateDataset = fullTrainDataset[1:100] + labelDict = [0:9...] + + trainData = MLUtils.DataLoader( + trainDataset; # fullTrainDataset or trainDataset + batchsize=100, + collate=true, + shuffle=true, + buffer=true, + partial=false, # better for gpu memory if batchsize is fixed + # parallel=true, #BUG ?? causing dataloader into forever loop + ) + + validateData = MLUtils.DataLoader( + validateDataset; + batchsize=1, + collate=true, + shuffle=true, + buffer=true, + partial=false, # better for gpu memory if batchsize is fixed + # parallel=true, #BUG ?? causing dataloader into forever loop + ) + + #CHANGE dummy data used to debug + # trainData = [(rand(10, 10), [5]), (rand(10, 10), [2])] + # trainData = [(rand(10, 10), [5]),] + + return trainData, validateData, labelDict +end + +function train_snn(model_name::String, filename::String, location::String, + trainData, validateData, labelDict::Vector) + println("loading SNN model") + + model = deserialize(location * sep * filename) + println("model loading completed") + + # random seed + # rng = MersenneTwister(1234) + + logitLog = zeros(10, 2) + firedNeurons_t1 = zeros(1) + var1 = zeros(10, 2) + var2 = zeros(10, 2) + var3 = zeros(10, 2) + var4 = zeros(10, 2) + + # ----------------------------------- plot ----------------------------------- # + plot10 = Observable(firedNeurons_t1) + + plot20 = Observable(logitLog[1 , :]) + plot21 = Observable(logitLog[2 , :]) + plot22 = Observable(logitLog[3 , :]) + plot23 = Observable(logitLog[4 , :]) + plot24 = Observable(logitLog[5 , :]) + plot25 = Observable(logitLog[6 , :]) + plot26 = Observable(logitLog[7 , :]) + plot27 = Observable(logitLog[8 , :]) + plot28 = Observable(logitLog[9 , :]) + plot29 = Observable(logitLog[10, :]) + + plot30 = Observable(var1[1 , :]) + plot31 = Observable(var1[2 , :]) + plot32 = Observable(var1[3 , :]) + plot33 = Observable(var1[4 , :]) + plot34 = Observable(var1[5 , :]) + plot35 = Observable(var1[6 , :]) + plot36 = Observable(var1[7 , :]) + plot37 = Observable(var1[8 , :]) + plot38 = Observable(var1[9 , :]) + plot39 = Observable(var1[10, :]) + + plot40 = Observable(var2[1 , :]) + plot41 = Observable(var2[2 , :]) + plot42 = Observable(var2[3 , :]) + plot43 = Observable(var2[4 , :]) + plot44 = Observable(var2[5 , :]) + plot45 = Observable(var2[6 , :]) + plot46 = Observable(var2[7 , :]) + plot47 = Observable(var2[8 , :]) + plot48 = Observable(var2[9 , :]) + plot49 = Observable(var2[10, :]) + + plot50 = Observable(var3[1 , :]) + plot51 = Observable(var3[2 , :]) + plot52 = Observable(var3[3 , :]) + plot53 = Observable(var3[4 , :]) + plot54 = Observable(var3[5 , :]) + plot55 = Observable(var3[6 , :]) + plot56 = Observable(var3[7 , :]) + plot57 = Observable(var3[8 , :]) + plot58 = Observable(var3[9 , :]) + plot59 = Observable(var3[10, :]) + + plot60 = Observable(var4[1 , :]) + plot61 = Observable(var4[2 , :]) + plot62 = Observable(var4[3 , :]) + plot63 = Observable(var4[4 , :]) + plot64 = Observable(var4[5 , :]) + plot65 = Observable(var4[6 , :]) + plot66 = Observable(var4[7 , :]) + plot67 = Observable(var4[8 , :]) + plot68 = Observable(var4[9 , :]) + plot69 = Observable(var4[10, :]) + + # main figure + fig1 = Figure() + + subfig1 = GLMakie.Axis(fig1[1, 1], # define position of this subfigure inside a figure + title = "RSNN firedNeurons_t1", + xlabel = "time", + ylabel = "data" + ) + lines!(subfig1, plot10, label = "firedNeurons_t1") + axislegend(subfig1, position = :lb) + + subfig2 = GLMakie.Axis(fig1[2, 1], # define position of this subfigure inside a figure + title = "output neurons activation", + xlabel = "time", + ylabel = "data" + ) + + lines!(subfig2, plot20, label = "0", color = 1, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot21, label = "1", color = 2, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot22, label = "2", color = 3, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot23, label = "3", color = 4, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot24, label = "4", color = 5, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot25, label = "5", color = 6, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot26, label = "6", color = 7, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot27, label = "7", color = 8, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot28, label = "8", color = 9, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig2, plot29, label = "9", color = 10, colormap = :tab10, colorrange = (1, 10)) + axislegend(subfig2, position = :lb) + + + subfig3 = GLMakie.Axis(fig1[3, 1], # define position of this subfigure inside a figure + title = "output neurons membrane potential v_t1", + xlabel = "time", + ylabel = "data" + ) + lines!(subfig3, plot30, label = "0", color = 1, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot31, label = "1", color = 2, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot32, label = "2", color = 3, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot33, label = "3", color = 4, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot34, label = "4", color = 5, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot35, label = "5", color = 6, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot36, label = "6", color = 7, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot37, label = "7", color = 8, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot38, label = "8", color = 9, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig3, plot39, label = "9", color = 10, colormap = :tab10, colorrange = (1, 10)) + axislegend(subfig3, position = :lb) + + subfig4 = GLMakie.Axis(fig1[4, 1], # define position of this subfigure inside a figure + title = "output neuron wRec", + xlabel = "time", + ylabel = "data" + ) + lines!(subfig4, plot40, label = "0", color = 1, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot41, label = "1", color = 2, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot42, label = "2", color = 3, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot43, label = "3", color = 4, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot44, label = "4", color = 5, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot45, label = "5", color = 6, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot46, label = "6", color = 7, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot47, label = "7", color = 8, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot48, label = "8", color = 9, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig4, plot49, label = "9", color = 10, colormap = :tab10, colorrange = (1, 10)) + axislegend(subfig4, position = :lb) + + subfig5 = GLMakie.Axis(fig1[5, 1], # define position of this subfigure inside a figure + title = "output neuron epsilonRec", + xlabel = "time", + ylabel = "data" + ) + lines!(subfig5, plot50, label = "0", color = 1, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot51, label = "1", color = 2, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot52, label = "2", color = 3, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot53, label = "3", color = 4, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot54, label = "4", color = 5, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot55, label = "5", color = 6, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot56, label = "6", color = 7, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot57, label = "7", color = 8, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot58, label = "8", color = 9, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig5, plot59, label = "9", color = 10, colormap = :tab10, colorrange = (1, 10)) + axislegend(subfig5, position = :lb) + + subfig6 = GLMakie.Axis(fig1[6, 1], # define position of this subfigure inside a figure + title = "output neuron wRecChange", + xlabel = "time", + ylabel = "data" + ) + lines!(subfig6, plot60, label = "0", color = 1, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot61, label = "1", color = 2, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot62, label = "2", color = 3, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot63, label = "3", color = 4, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot64, label = "4", color = 5, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot65, label = "5", color = 6, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot66, label = "6", color = 7, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot67, label = "7", color = 8, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot68, label = "8", color = 9, colormap = :tab10, colorrange = (1, 10) ) + lines!(subfig6, plot69, label = "9", color = 10, colormap = :tab10, colorrange = (1, 10)) + axislegend(subfig6, position = :lb) + + # wait(display(fig1)) + # display(fig1) + # --------------------------------- end plot --------------------------------- # + + # model learning + maxRepeatRound = 1 # repeat each image + thinkingPeriod = 16 # 1000-784 = 216 + for epoch = 1:1000 + println("epoch $epoch") + for (imgBatch, labelBatch) in trainData + @showprogress for i in eachindex(labelBatch) + _img = (imgBatch[:, :, i]) + img = reshape(_img, (:, 1)) + row, col = size(img) + label = labelBatch[i] + println("epoch $epoch training label $label") + + img_tensor = torch.from_numpy( np.asarray(img) ) + + # create more data for RSNN + spike = spikegen.delta(img_tensor, threshold=0.1, off_spike=true) + spike1 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike2 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike3 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike4 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike5 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike6 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike7 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike8 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike9 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike10 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.2, off_spike=true) + spike11 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike12 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike13 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike14 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike15 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike16 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike17 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike18 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike19 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike20 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.3, off_spike=true) + spike21 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike22 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike23 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike24 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike25 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike26 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike27 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike28 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike29 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike30 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.4, off_spike=true) + spike31 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike32 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike33 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike34 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike35 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike36 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike37 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike38 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike39 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike40 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.5, off_spike=true) + spike41 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike42 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike43 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike44 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike45 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike46 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike47 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike48 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike49 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike50 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + input = [spike1;; spike2;; spike3;; spike4;; spike5;; spike6;; spike7;; spike8;; spike9;; spike10;; + spike11;; spike12;; spike13;; spike14;; spike15;; spike16;; spike17;; spike18;; spike19;; spike20;; + spike21;; spike22;; spike23;; spike24;; spike25;; spike26;; spike27;; spike28;; spike29;; spike30;; + spike31;; spike32;; spike33;; spike34;; spike35;; spike36;; spike37;; spike38;; spike39;; spike40;; + spike41;; spike42;; spike43;; spike44;; spike45;; spike46;; spike47;; spike48;; spike49;; spike50 + ]' # ' to flip 784x10 to 10x784 + + predict = 0 + + for k in 1:maxRepeatRound + + # insert data into model sequencially + for i in 1:(row + thinkingPeriod) # sMNIST ihas 784 timestep(pixel) + thinking period = 1000 timestep + tick = i + if i <= row + current_pixel = input[:, i] + else + current_pixel = zeros(size(input)[1]) # dummy input in "thinking" period + end + + if tick == 1 # tell a model to start learning. 1-time only + model.learningStage = "start_learning" + + elseif tick == (row+thinkingPeriod) + model.learningStage = "end_learning" + else + end + + _firedNeurons_t1, logit, _var1, _var2, _var3, _var4 = model(current_pixel) + # log answer of all timestep + logitLog = [logitLog;; logit] + firedNeurons_t1 = push!(firedNeurons_t1, _firedNeurons_t1) + var1 = [var1;; _var1] + var2 = [var2;; _var2] + var3 = [var3;; _var3] + var4 = [var4;; _var4] + + # if tick <= row # online learning, 1-by-1 timestep + # correctAnswer = zeros(length(logit)) + # modelError = (logit - correctAnswer) * 1.0 + # Ironpen.compute_wRecChange!(model, modelError, correctAnswer) + # elseif tick == row+1 + # correctAnswer = OneHotArrays.onehot(label, labelDict) + # modelError = (logit - correctAnswer) * 1.0 + # Ironpen.compute_wRecChange!(model, modelError, correctAnswer) + # elseif tick > row+1 && tick < row+thinkingPeriod + # correctAnswer = OneHotArrays.onehot(label, labelDict) + # modelError = (logit - correctAnswer) * 1.0 + # Ironpen.compute_wRecChange!(model, modelError, correctAnswer) + # elseif tick == row+thinkingPeriod + # _predict = logitLog[:, end-thinkingPeriod+1:end] # answer count during thinking period + # _predict = Int.([sum(row) for row in eachrow(_predict)]) + # # predict = [x > 0 for x in _predict] + # correctAnswer = OneHotArrays.onehot(label, labelDict) + # modelError = (logit - correctAnswer) * 1.0 + # Ironpen.compute_wRecChange!(model, modelError, correctAnswer) + # Ironpen.learn!(model) + # println("label $label predict $(_predict) model error $(Int.(modelError))") + # else + # error("undefined condition line $(@__LINE__)") + # end + + if tick <= row # online learning, 1-by-1 timestep + # no error calculation + elseif tick > row && tick < row+thinkingPeriod + # correctAnswer = OneHotArrays.onehot(label, labelDict) + # modelError = (logit - correctAnswer) * 1.0 + # Ironpen.compute_wRecChange!(model, modelError, correctAnswer) + + elseif tick == row+thinkingPeriod + correctAnswer = OneHotArrays.onehot(label, labelDict) + modelError = Flux.logitcrossentropy(logit, correctAnswer) * 1.0 + outputError = (logit - correctAnswer) * 1.0 + Ironpen.compute_wRecChange!(model, modelError, outputError) + Ironpen.learn!(model) + _logit = round.(logit; digits=2) + predict = findall(isequal.(logit, maximum(logit)))[1] - 1 + y = round.(modelError; digits=2) + println("") + println("label $label predict $predict logit $_logit model error $y") + else + error("undefined condition line $(@__LINE__)") + end + + # update plot + plot10[] = firedNeurons_t1 + + plot20[] = view(logitLog, 1 , :) + plot21[] = view(logitLog, 2 , :) + plot22[] = view(logitLog, 3 , :) + plot23[] = view(logitLog, 4 , :) + plot24[] = view(logitLog, 5 , :) + plot25[] = view(logitLog, 6 , :) + plot26[] = view(logitLog, 7 , :) + plot27[] = view(logitLog, 8 , :) + plot28[] = view(logitLog, 9 , :) + plot29[] = view(logitLog, 10, :) + + plot30[] = view(var1, 1 , :) + plot31[] = view(var1, 2 , :) + plot32[] = view(var1, 3 , :) + plot33[] = view(var1, 4 , :) + plot34[] = view(var1, 5 , :) + plot35[] = view(var1, 6 , :) + plot36[] = view(var1, 7 , :) + plot37[] = view(var1, 8 , :) + plot38[] = view(var1, 9 , :) + plot39[] = view(var1, 10, :) + + plot40[] = view(var2, 1 , :) + plot41[] = view(var2, 2 , :) + plot42[] = view(var2, 3 , :) + plot43[] = view(var2, 4 , :) + plot44[] = view(var2, 5 , :) + plot45[] = view(var2, 6 , :) + plot46[] = view(var2, 7 , :) + plot47[] = view(var2, 8 , :) + plot48[] = view(var2, 9 , :) + plot49[] = view(var2, 10, :) + + plot50[] = view(var3, 1 , :) + plot51[] = view(var3, 2 , :) + plot52[] = view(var3, 3 , :) + plot53[] = view(var3, 4 , :) + plot54[] = view(var3, 5 , :) + plot55[] = view(var3, 6 , :) + plot56[] = view(var3, 7 , :) + plot57[] = view(var3, 8 , :) + plot58[] = view(var3, 9 , :) + plot59[] = view(var3, 10, :) + + plot60[] = view(var4, 1 , :) + plot61[] = view(var4, 2 , :) + plot62[] = view(var4, 3 , :) + plot63[] = view(var4, 4 , :) + plot64[] = view(var4, 5 , :) + plot65[] = view(var4, 6 , :) + plot66[] = view(var4, 7 , :) + plot67[] = view(var4, 8 , :) + plot68[] = view(var4, 9 , :) + plot69[] = view(var4, 10, :) + end + # end-thinkingPeriod+2; +2 because initialize logitLog = zeros(10, 2) + # _modelRespond = logitLog[:, end-thinkingPeriod+2:end] # answer count during thinking period + # _modelRespond = [sum(i) for i in eachrow(_modelRespond)] + # modelRespond = isequal.(isequal.(_modelRespond, 0), 0) + + display(fig1) + # sleep(1) + if k % 3 == 0 + firedNeurons_t1 = zeros(1) + logitLog = zeros(10, 2) + var1 = zeros(10, 2) + var2 = zeros(10, 2) + var3 = zeros(10, 2) + var4 = zeros(10, 2) + end + + # if predict == OneHotArrays.onehot(label, labelDict) + # println("model train $label successfully, $k tries") + # # wait(display(fig1)) + + # firedNeurons_t1 = zeros(1) + # logitLog = zeros(10, 2) + # var1 = zeros(10, 2) + # var2 = zeros(10, 2) + # var3 = zeros(10, 2) + # var4 = zeros(10, 2) + # break + # end + + if k == maxRepeatRound + # println("model train $label unsuccessfully, $maxRepeatRound tries, skip training") + # display(fig1) + firedNeurons_t1 = zeros(1) + logitLog = zeros(10, 2) + var1 = zeros(10, 2) + var2 = zeros(10, 2) + var3 = zeros(10, 2) + var4 = zeros(10, 2) + break + end + + GC.gc() + end + end + # check accuracy + println("validating model") + answerCorrectly = validate(model, validateData, labelDict) + println("model accuracy is $answerCorrectly %") + end + + # # check mean error and accuracy + # mean_error = round(mean(model_error_list), sigdigits = 3) + # accuracy = round(model_accuracy / batch_size * 100, sigdigits = 3) + # println("------------") + # println(model_name) + # println("mean error $mean_error accuracy $accuracy") + end +end + +function validate(model, dataset, labelDict) + answerCorrectly = 0.0 # % + thinkingPeriod = 16 # 1000-784 = 216 + @showprogress for (image, label) in dataset + img = reshape(image, (:, 1)) + row, col = size(img) + label = label[1] + + img_tensor = torch.from_numpy( np.asarray(img) ) + + # create more data for RSNN + spike = spikegen.delta(img_tensor, threshold=0.1, off_spike=true) + spike1 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike2 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike3 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike4 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike5 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike6 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike7 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike8 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike9 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike10 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.2, off_spike=true) + spike11 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike12 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike13 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike14 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike15 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike16 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike17 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike18 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike19 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike20 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.3, off_spike=true) + spike21 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike22 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike23 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike24 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike25 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike26 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike27 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike28 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike29 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike30 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.4, off_spike=true) + spike31 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike32 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike33 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike34 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike35 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike36 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike37 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike38 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike39 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike40 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + spike = spikegen.delta(img_tensor, threshold=0.5, off_spike=true) + spike41 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike42 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike43 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike44 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike45 = isequal.(pyconvert(Array, spike.data.numpy()), 1) + spike46 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike47 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike48 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike49 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + spike50 = isequal.(pyconvert(Array, spike.data.numpy()), -1) + + input = [spike1;; spike2;; spike3;; spike4;; spike5;; spike6;; spike7;; spike8;; spike9;; spike10;; + spike11;; spike12;; spike13;; spike14;; spike15;; spike16;; spike17;; spike18;; spike19;; spike20;; + spike21;; spike22;; spike23;; spike24;; spike25;; spike26;; spike27;; spike28;; spike29;; spike30;; + spike31;; spike32;; spike33;; spike34;; spike35;; spike36;; spike37;; spike38;; spike39;; spike40;; + spike41;; spike42;; spike43;; spike44;; spike45;; spike46;; spike47;; spike48;; spike49;; spike50 + ]' # ' to flip 784x10 to 10x784 + + # insert data into model sequencially + logit = Float64[] + for i in 1:(row + thinkingPeriod) # sMNIST ihas 784 timestep(pixel) + thinking period = 1000 timestep + if i <= row + current_pixel = input[:, i] + else + current_pixel = zeros(size(input)[1]) # dummy input in "thinking" period + end + + _firedNeurons_t1, logit, _var1, _var2, _var3, _var4 = model(current_pixel) + end + + predict = findall(isequal.(logit, maximum(logit)))[1] - 1 + if predict == label + answerCorrectly += 1 + # println("model answer $label correctly") + else + # println("img $label, model answer $predict") + end + GC.gc() + end + + correctPercent = answerCorrectly * 100.0 / length(dataset) + + return correctPercent::Float64 +end + + +function main() + training_start_time = Dates.now() + println("program started ", training_start_time) + + filelocation = string(@__DIR__) + + # generate SNN + for i = 1:1 + modelname = "v06_36" + filename = "$modelname.jl163" + generate_snn(filename, filelocation) + end + + modelname = "v06_36" + filename = "$modelname.jl163" + # filename = "v06_31c.jl163" + + trainDataset, validateDataset, labelDict = data_loader() + + train_snn(modelname, filename, filelocation, trainDataset, validateDataset, labelDict) + + finish_training_time = Dates.now() + println("training done, $training_start_time ==> $finish_training_time ") + println(" ///////////////////////////////////////////////////////////////////////") +end + +# only runs main() if julia isn’t started interactively +# https://discourse.julialang.org/t/scripting-like-a-julian/50707 +!isinteractive() && main() +#------------------------------------------------------------------------------------------------100 + + + + + + diff --git a/src/Ironpen.jl b/src/Ironpen.jl index f386398..ae9b815 100644 --- a/src/Ironpen.jl +++ b/src/Ironpen.jl @@ -34,7 +34,6 @@ using .learn """ version 0.0.5 Todo: - - [4] implement dormant connection [] using RL to control learning signal [] consider using Dates.now() instead of timestamp because time_stamp may overflow @@ -43,34 +42,60 @@ using .learn Change from version: 0.0.4 - - compute error in main loop so one could decide how to calculate error - - compute model error in main loop so one could decide when to calculate error + - compute model error in main loop so one could decide when to calculate error in + training sequence and how to calculate + - fix ALIF adaptation formula, now n.a compute avery time step + - add higher input signal to noise ratio + - no noise generate + - increase input signal by adding more input neuron population + - add /100 every wRec and b + - add integrateNeuron + - ΔwRecChange during input signal ingestion will be merged at the end of learning + - use Flux.logitcrossentropy for overall error + - move timestep_forward!() in kfn's forward to the beginning so that v_t and z_t is reset + - fix n.a formula in forward() and calculate both non-firing and firing state + - RSNN use overall modelError to update while integrate neuron use error with respect to + itself (yk - yk*) to update + - all RSNN neuron connect to integrateNeuron + - integrateNeuron does NOT repect RSNN excitatory and inhabitory sign + - weaker connection should be harder to increase strength. It requires a lot of + repeat activation to get it stronger. While strong connction requires a lot of + inactivation to get it weaker. The concept is strong connection will lock + correct neural pathway through repeated use of the right connection i.e. keep training + on the correct answer -> strengthen the right neural pathway (connections) -> + this correct neural pathway resist to change. + Not used connection should dissapear (forgetting). All features - - ΔwRecChange during input signal ingestion will be merged at the end of learning - - all RSNN and output neuron learning associate. - synapticStrength apply at the end of learning - collect ΔwRecChange during online learning (0-784th) and merge with wRec at - the end learning (1000th). - - compute model error at the end learning. Model error times with 5 constant for - higher learning impact than the error during online + the end learning (800th). - multidispatch + for loop as main compute method - - hard connection constrain yes - - normalize output yes - allow -w_rec yes - - voltage drop when neuron fires voltage drop equals to vth - - v_t decay during refractory - duration exponantial decay + - voltage drop when neuron fires voltage drop equals to vRest + - v_t decay during refractory - input data population encoding, each pixel data => population encoding, ralative between pixel data - compute neuron weight init rand() - output neuron weight init randn() - - each knowledgeFn should have its own noise generater - - where to put pseudo derivative (n.phi) + + - compute pseudo derivative (n.phi) every time step - add excitatory, inhabitory to neuron - implement "start learning", reset learning and "learning", "end_learning and "inference" + - synaptic connection strength concept. use sigmoid, turn connection offline + - neuroplasticity() i.e. change connection + - add multi threads + + + Removed features + - normalize output yes + + - compute model error at the end learning. Model error times with 5 constant for + higher learning impact than the error during online + - output neuron connect to random multiple compute neurons and overall have the same structure as lif - time-based learning method based on new error formula @@ -79,28 +104,23 @@ using .learn (vth - vt)*100/vth as error if output neuron activates when it should NOT, use output neuron's (vt*100)/vth as error + - use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge + - reset_epsilonRec after ΔwRecChange is calculated - - synaptic connection strength concept. use sigmoid, turn connection offline - - wRec should not normalized whole. it should be local 5 conn normalized. - - neuroplasticity() i.e. change connection - - add multi threads - - add maximum weight cap of each connection - + + - add maximum weight cap of each connection + + - wRec should not normalized whole. it should be local 5 conn normalized. + - Removed features + + Ideas to try + - Δweight * connection strength + - reset_epsilonRec after ΔwRecChange is calculated - ΔwRecChange that apply immediately during online learning - - error by percent of vth-v_t1 - - Δweight * connection strength - - weaker connection should be harder to increase strength. It requires a lot of - repeat activation to get it stronger. While strong connction requires a lot of - inactivation to get it weaker. The concept is strong connection will lock - correct neural pathway through repeated use of the right connection i.e. keep training - on the correct answer -> strengthen the right neural pathway (connections) -> - this correct neural pathway resist to change. - Not used connection should dissapear (forgetting). - - during 0 training if 1-9 output neuron fires, adjust weight only those neurons - - use Flux.logitcrossentropy for overall error """ diff --git a/src/forward.jl b/src/forward.jl index 05dab9c..b7bf914 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -1,6 +1,6 @@ module forward -using Statistics, Random, LinearAlgebra, JSON3 +using Statistics, Random, LinearAlgebra, JSON3, Flux using GeneralUtils using ..types, ..snn_utils @@ -26,7 +26,13 @@ end function (kfn::kfn_1)(m::model, input_data::AbstractVector) kfn.timeStep = m.timeStep - + for n in kfn.neuronsArray + timestep_forward!(n) + end + for n in kfn.outputNeuronsArray + timestep_forward!(n) + end + kfn.learningStage = m.learningStage if kfn.learningStage == "start_learning" @@ -54,19 +60,12 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector) end # generate noise - noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5]) + noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.0, 1.0]) for i in 1:length(input_data)] # noise = [rand(rng, Distributions.Binomial(1, 0.5)) for i in 1:10] # another option input_data = [noise; input_data] # noise must start from neuron id 1 - for n in kfn.neuronsArray - timestep_forward!(n) - end - for n in kfn.outputNeuronsArray - timestep_forward!(n) - end - # pass input_data into input neuron. # number of data point equals to number of input neuron starting from id 1 for (i, data) in enumerate(input_data) @@ -75,8 +74,8 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector) kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] - # Threads.@threads for n in kfn.neuronsArray - for n in kfn.neuronsArray + Threads.@threads for n in kfn.neuronsArray + # for n in kfn.neuronsArray n(kfn) end @@ -84,19 +83,22 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector) append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires kfn.firedNeurons |> unique! # use for random new neuron connection - # Threads.@threads for n in kfn.outputNeuronsArray - for n in kfn.outputNeuronsArray + Threads.@threads for n in kfn.outputNeuronsArray + # for n in kfn.outputNeuronsArray n(kfn) end - out = [n.z_t1 for n in kfn.outputNeuronsArray] + logit = [n.v_t1 for n in kfn.outputNeuronsArray] - return out::Array{Bool}, - sum(kfn.firedNeurons_t1), + # _predict = Flux.softmax(logit) + # predict = findall(isequal.(_predict, maximum(_predict)))[1] + + return sum(kfn.firedNeurons_t1[kfn.kfnParams[:totalInputPort]+1:end])::Int, + logit::Array{Float64}, [n.v_t1 for n in kfn.outputNeuronsArray], [sum(i.wRec) for i in kfn.outputNeuronsArray], [sum(i.epsilonRec) for i in kfn.outputNeuronsArray], - [i.phi for i in kfn.outputNeuronsArray] + [sum(i.wRecChange) for i in kfn.outputNeuronsArray] end #------------------------------------------------------------------------------------------------100 @@ -128,11 +130,15 @@ function (n::lifNeuron)(kfn::knowledgeFn) # decay of v_t1 n.v_t1 = n.alpha * n.v_t + + n.phi = 0.0 + n.decayedEpsilonRec = n.alpha * n.epsilonRec + n.epsilonRec = n.decayedEpsilonRec else n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed n.alpha_v_t = n.alpha * n.v_t n.v_t1 = n.alpha_v_t + n.recSignal - n.v_t1 = no_negative!(n.v_t1) + # n.v_t1 = no_negative!(n.v_t1) if n.v_t1 > n.v_th n.z_t1 = true @@ -147,7 +153,7 @@ function (n::lifNeuron)(kfn::knowledgeFn) n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th) n.decayedEpsilonRec = n.alpha * n.epsilonRec n.epsilonRec = n.decayedEpsilonRec + n.z_i_t - end + end end #------------------------------------------------------------------------------------------------100 @@ -165,26 +171,30 @@ function (n::alifNeuron)(kfn::knowledgeFn) # neuron is in refractory state, skip all calculation n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because neuron spike last only 1 timestep follow by a period of refractory. - n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t) + n.a = (n.rho * n.a) n.recSignal = n.recSignal * 0.0 # decay of v_t1 n.v_t1 = n.alpha * n.v_t - n.phi = 0 + + n.phi = 0.0 + n.decayedEpsilonRec = n.alpha * n.epsilonRec + n.epsilonRec = n.decayedEpsilonRec else - n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t) n.av_th = n.v_th + (n.beta * n.a) n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed n.alpha_v_t = n.alpha * n.v_t n.v_t1 = n.alpha_v_t + n.recSignal - n.v_t1 = no_negative!(n.v_t1) + # n.v_t1 = no_negative!(n.v_t1) if n.v_t1 > n.av_th n.z_t1 = true n.refractoryCounter = n.refractoryDuration n.firingCounter += 1 n.v_t1 = n.vRest + n.a = (n.rho * n.a) + 1.0 else n.z_t1 = false + n.a = (n.rho * n.a) end # there is a difference from lif formula @@ -219,12 +229,16 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn # decay of v_t1 n.v_t1 = n.alpha * n.v_t n.vError = n.v_t1 # store voltage that will be used to calculate error later + + n.phi = 0.0 + n.decayedEpsilonRec = n.alpha * n.epsilonRec + n.epsilonRec = n.decayedEpsilonRec else recSignal = n.wRec .* n.z_i_t n.recSignal = sum(recSignal) # signal from other neuron that this neuron subscribed n.alpha_v_t = n.alpha * n.v_t n.v_t1 = n.alpha_v_t + n.recSignal - n.v_t1 = no_negative!(n.v_t1) + # n.v_t1 = no_negative!(n.v_t1) n.vError = n.v_t1 # store voltage that will be used to calculate error later if n.v_t1 > n.v_th n.z_t1 = true @@ -242,6 +256,30 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn end end +#------------------------------------------------------------------------------------------------100 + +""" integrateNeuron forward() +""" +function (n::integrateNeuron)(kfn::knowledgeFn) + n.timeStep = kfn.timeStep + + # pulling other neuron's firing status at time t + n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList) + n.z_i_t_commulative += n.z_i_t + + n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed + n.alpha_v_t = n.alpha * n.v_t + if n.recSignal <= 0 + n.v_t1 = n.alpha_v_t + else + n.v_t1 = n.alpha_v_t + n.recSignal + n.b + end + + # there is a difference from alif formula + n.decayedEpsilonRec = n.alpha * n.epsilonRec + n.epsilonRec = n.decayedEpsilonRec + n.z_i_t +end + diff --git a/src/learn.jl b/src/learn.jl index 4c9ed5c..022415f 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -8,15 +8,9 @@ export learn!, compute_wRecChange!, computeModelError #------------------------------------------------------------------------------------------------100 - -function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0) - error = ((correctAnswer .- modelRespond) .* magnitude) - - return error::Vector{Float64} -end - -function compute_wRecChange!(m::model, error::Vector{Float64}, correctAnswer::AbstractVector) - compute_wRecChange!(m.knowledgeFn[:I], error, correctAnswer) +function compute_wRecChange!(m::model, modelError::Float64, outputError::Vector{Float64}) + # normalize!(modelError) + compute_wRecChange!(m.knowledgeFn[:I], modelError, outputError) end # function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector) @@ -47,54 +41,83 @@ end # end -function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector) - for (i, error) in enumerate(errors) - if error < 0 # model fires too fast - error = error * - abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) - elseif error == 0 # model answer correctly. maintain membrain potential ≈ 0.5 - error = error * - abs(kfn.outputNeuronsArray[i].v_th/2 - kfn.outputNeuronsArray[i].vError) - else # model fires too slow - error = error * - abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) +function compute_wRecChange!(kfn::kfn_1, modelError::Float64, outputError::Vector{Float64}) + Threads.@threads for n in kfn.neuronsArray + # for n in kfn.neuronsArray + if typeof(n)<: computeNeuron + # wIndex = findall(isequal.(oN.subscriptionList, n.id)) + wOut = abs.([oN.wRec[findall(isequal.(oN.subscriptionList, n.id))[1]] + for oN in kfn.outputNeuronsArray]) + compute_wRecChange!(n, wOut, modelError) end - Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error - # for n in kfn.neuronsArray - compute_wRecChange!(n, error) - end - compute_wRecChange!(kfn.outputNeuronsArray[i], error) + end + + for oN in kfn.outputNeuronsArray + compute_wRecChange!(oN, outputError[oN.id]) end end -function compute_wRecChange!(n::passthroughNeuron, error::Float64) +function compute_wRecChange!(n::passthroughNeuron, wOut::AbstractVector, modelError::Vector{Float64}) # skip end -function compute_wRecChange!(n::lifNeuron, error::Float64) +function compute_wRecChange!(n::lifNeuron, wOut::AbstractVector, modelError::Float64) + # how much error of this neuron 1-spike causing each output neuron's error + nError = sum(wOut * modelError) + n.eRec = n.phi * n.epsilonRec - ΔwRecChange = n.eta * error * n.eRec + ΔwRecChange = -n.eta * nError * n.eRec + # if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all + # ΔwRecChange .+= (0.2*(abs(sum(n.wRec)) / length(n.wRec))) + # end n.wRecChange .+= ΔwRecChange - reset_epsilonRec!(n) + # reset_epsilonRec!(n) end -function compute_wRecChange!(n::alifNeuron, error::Float64) +function compute_wRecChange!(n::alifNeuron, wOut::AbstractVector, modelError::Float64) + # how much error of this neuron 1-spike causing each output neuron's error + # (prejected throug wOut) + nError = sum(wOut * modelError) + n.eRec_v = n.phi * n.epsilonRec n.eRec_a = n.phi * n.beta * n.epsilonRecA n.eRec = n.eRec_v + n.eRec_a - ΔwRecChange = n.eta * error * n.eRec + ΔwRecChange = -n.eta * nError * n.eRec + # if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all + # ΔwRecChange .+= (0.2*(abs(sum(n.wRec)) / length(n.wRec))) + # end n.wRecChange .+= ΔwRecChange - reset_epsilonRec!(n) - reset_epsilonRecA!(n) + + # reset_epsilonRec!(n) + # reset_epsilonRecA!(n) + # n.alphaChange += compute_alphaChange(n.eta, nError) end -function compute_wRecChange!(n::linearNeuron, error::Float64) - n.eRec = n.phi * n.epsilonRec - ΔwRecChange = n.eta * error * n.eRec +function compute_wRecChange!(n::integrateNeuron, error::Float64) + ΔwRecChange = -n.eta * error * n.epsilonRec + ΔbChange = -n.eta * error + # if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all + # ΔwRecChange .+= (abs(sum(n.wRec)) / length(n.wRec)) + # end n.wRecChange .+= ΔwRecChange - reset_epsilonRec!(n) + n.bChange += ΔbChange + # reset_epsilonRec!(n) + # n.alphaChange += compute_alphaChange(n.eta, error) end +# function compute_wRecChange!(n::linearNeuron, error::Float64) +# n.eRec = n.phi * n.epsilonRec +# ΔwRecChange = -n.eta * error * n.eRec +# # if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all +# # ΔwRecChange .+= (abs(sum(n.wRec)) / length(n.wRec)) +# # end +# n.wRecChange .+= ΔwRecChange +# # reset_epsilonRec!(n) +# end + +# add compute_alphaChange +compute_alphaChange(learningRate::Float64, total_wRecChange) = -learningRate * total_wRecChange + function learn!(m::model) learn!(m.knowledgeFn[:I]) end @@ -105,8 +128,8 @@ function learn!(kfn::kfn_1) # compute kfn error for each neuron Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error # for n in kfn.neuronsArray - learn!(n, kfn.firedNeurons, kfn.nExInType) - end + learn!(n, kfn.firedNeurons, kfn.nExInType) + end for n in kfn.outputNeuronsArray learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort]) end @@ -124,35 +147,70 @@ end function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron wSign_0 = sign.(n.wRec) # original sign # n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) - n.wRec += n.wRecChange # merge wRecChange into wRec - reset_wRecChange!(n) + + wRecChange_reduceCoeff = 1.0 + # wRecChange_max = 0.2 * abs(sum(n.wRec)) # max change 20% + # y = abs(sum(n.wRecChange)) + # if y > wRecChange_max # capping weight update + # wRecChange_reduceCoeff = wRecChange_max / y + # end + n.wRec += (wRecChange_reduceCoeff * n.wRecChange) + # n.alpha += n.alphaChange + wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped # normalize wRec peak to prevent input signal overwhelming neuron - normalizePeak!(n.wRec, n.wRecChange, 2) + # if sum(n.wRecChange) != 0 + # normalizePeak!(n.wRec, n.wRecChange, 2) + # end # set weight that fliped sign to 0 for random new connection n.wRec .*= nonFlipedSign - capMaxWeight!(n.wRec) # cap maximum weight + # capMaxWeight!(n.wRec) # cap maximum weight synapticConnStrength!(n, "updown") neuroplasticity!(n, firedNeurons, nExInType) end -function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron - wSign_0 = sign.(n.wRec) # original sign - # n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) - n.wRec += n.wRecChange - reset_wRecChange!(n) - wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign - nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped - # normalize wRec peak to prevent input signal overwhelming neuron - normalizePeak!(n.wRec, n.wRecChange, 2) - # set weight that fliped sign to 0 for random new connection - n.wRec .*= nonFlipedSign - capMaxWeight!(n.wRec) # cap maximum weight - synapticConnStrength!(n, "updown") - neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) +function learn!(n::integrateNeuron, firedNeurons, nExInType, totalInputPort) + wRecChange_reduceCoeff = 1.0 + # wRecChange_max = 0.2 * abs(sum(n.wRec)) # max change 20% + # y = abs(sum(n.wRecChange)) + # if y > wRecChange_max # capping weight update + # wRecChange_reduceCoeff = wRecChange_max / y + # end + n.wRec += (wRecChange_reduceCoeff * n.wRecChange) + n.b += (wRecChange_reduceCoeff * n.bChange) + # n.alpha += n.alphaChange end +# function learn!(n::linearNeuron, firedNeurons, nExInType, totalInputPort) +# wSign_0 = sign.(n.wRec) # original sign +# # n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) +# wRecChange_max = 0.1 * abs(sum(n.wRec)) # max change 20% +# y = abs(sum(n.wRecChange)) +# wRecChange_reduceCoeff = 1.0 +# # if y > wRecChange_max # capping weight update +# # wRecChange_reduceCoeff = wRecChange_max / y +# # end +# n.wRec += (wRecChange_reduceCoeff * n.wRecChange) +# n.alpha += n.alphaChange + +# wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign +# nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped +# # normalize wRec peak to prevent input signal overwhelming neuron +# if sum(n.wRecChange) != 0 +# # normalizePeak!(n.wRec, n.wRecChange, 2) +# end +# # set weight that fliped sign to 0 for random new connection +# # n.wRec .*= nonFlipedSign +# # capMaxWeight!(n.wRec) # cap maximum weight +# # synapticConnStrength!(n, "updown") +# # neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) +# end + + + + + diff --git a/src/snn_utils.jl b/src/snn_utils.jl index 5c5aae1..24c1377 100644 --- a/src/snn_utils.jl +++ b/src/snn_utils.jl @@ -13,6 +13,7 @@ using GeneralUtils using ..types #------------------------------------------------------------------------------------------------100 +rng = MersenneTwister(1234) function timestep_forward!(x::passthroughNeuron) x.z_t = x.z_t1 @@ -30,6 +31,7 @@ precision(x::Array{<:Array}) = ( std(mean.(x)) / mean(mean.(x)) ) * 100 reset_last_firing_time!(n::computeNeuron) = n.lastFiringTime = 0.0 reset_refractory_state_active!(n::computeNeuron) = n.refractory_state_active = false reset_v_t!(n::neuron) = n.v_t = n.vRest +reset_v_t1!(n::neuron) = n.v_t1 = n.vRest reset_z_t!(n::computeNeuron) = n.z_t = false reset_epsilonRec!(n::computeNeuron) = n.epsilonRec = n.epsilonRec * 0.0 reset_epsilonRec!(n::outputNeuron) = n.epsilonRec = n.epsilonRec * 0.0 @@ -46,66 +48,23 @@ reset_firing_counter!(n::Union{computeNeuron, outputNeuron}) = n.firingCounter = reset_firing_diff!(n::Union{computeNeuron, outputNeuron}) = n.firingDiff = n.firingDiff * 0.0 reset_refractoryCounter!(n::Union{computeNeuron, outputNeuron}) = n.refractoryCounter = n.refractoryCounter * 0.0 reset_z_i_t_commulative!(n::Union{computeNeuron, outputNeuron}) = n.z_i_t_commulative = n.z_i_t_commulative * 0.0 +reset_alphaChange!(n::Union{computeNeuron, outputNeuron}) = n.alphaChange = n.alphaChange * 0.0 # reset function for output neuron reset_epsilon_j!(n::linearNeuron) = n.epsilon_j = n.epsilon_j * 0.0 reset_out_t!(n::linearNeuron) = n.out_t = n.out_t * 0.0 -reset_w_out_change!(n::linearNeuron) = n.w_out_change = n.w_out_change * 0.0 -reset_b_change!(n::linearNeuron) = n.b_change = n.b_change * 0.0 - - -""" Reset a part of learning-related params that used to collect learning history during learning - session -""" -# function reset_learning_no_wchange!(n::lifNeuron) -# reset_epsilonRec!(n) -# # reset_v_t!(n) -# # reset_z_t!(n) -# # reset_reg_voltage_a!(n) -# # reset_reg_voltage_b!(n) -# # reset_reg_voltage_error!(n) -# reset_firing_counter!(n) -# reset_firing_diff!(n) -# reset_previous_error!(n) -# reset_error!(n) - -# # # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state, -# # # it will stay in refractory state forever -# # reset_refractory_state_active!(n) -# end -# function reset_learning_no_wchange!(n::Union{alifNeuron, elif_neuron}) -# reset_epsilonRec!(n) -# reset_epsilonRecA!(n) -# reset_v_t!(n) -# reset_z_t!(n) -# # reset_a!(n) -# reset_reg_voltage_a!(n) -# reset_reg_voltage_b!(n) -# reset_reg_voltage_error!(n) -# reset_firing_counter!(n) -# reset_firing_diff!(n) -# reset_previous_error!(n) -# reset_error!(n) - -# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state, -# # it will stay in refractory state forever -# reset_refractory_state_active!(n) -# end -# function reset_learning_no_wchange!(n::linearNeuron) -# reset_epsilon_j!(n) -# reset_out_t!(n) -# reset_error!(n) -# end +reset_bChange!(n::integrateNeuron) = n.bChange = n.bChange * 0.0 """ Reset all learning-related params at the END of learning session """ function resetLearningParams!(n::lifNeuron) reset_epsilonRec!(n) reset_wRecChange!(n) - # reset_v_t!(n) - # reset_z_t!(n) + reset_v_t!(n) + reset_z_t!(n) reset_firing_counter!(n) reset_firing_diff!(n) + reset_alphaChange!(n) # reset refractory state at the start/end of episode. Otherwise once neuron goes into # refractory state, it will stay in refractory state forever @@ -116,11 +75,12 @@ function resetLearningParams!(n::alifNeuron) reset_epsilonRec!(n) reset_epsilonRecA!(n) reset_wRecChange!(n) - # reset_v_t!(n) - # reset_z_t!(n) - # reset_a!(n) + reset_v_t!(n) + reset_z_t!(n) + reset_a!(n) reset_firing_counter!(n) reset_firing_diff!(n) + reset_alphaChange!(n) # reset refractory state at the start/end of episode. Otherwise once neuron goes into # refractory state, it will stay in refractory state forever @@ -135,21 +95,31 @@ function resetLearningParams!(n::passthroughNeuron) # skip end -function resetLearningParams!(n::linearNeuron) +# function resetLearningParams!(n::linearNeuron) +# reset_epsilonRec!(n) +# reset_wRecChange!(n) +# # reset_v_t!(n) +# reset_firing_counter!(n) + +# # reset refractory state at the start/end of episode. Otherwise once neuron goes into +# # refractory state, it will stay in refractory state forever +# # reset_refractoryCounter!(n) +# reset_z_i_t_commulative!(n) +# end + +function resetLearningParams!(n::integrateNeuron) reset_epsilonRec!(n) reset_wRecChange!(n) - # reset_v_t!(n) + reset_bChange!(n) + reset_v_t!(n) reset_firing_counter!(n) - - # reset refractory state at the start/end of episode. Otherwise once neuron goes into - # refractory state, it will stay in refractory state forever - # reset_refractoryCounter!(n) - reset_z_i_t_commulative!(n) + reset_alphaChange!(n) end + #------------------------------------------------------------------------------------------------100 function store_knowledgefn_error!(kfn::knowledgeFn) - # condition to adjust nueron in KFN plane in addition to weight adjustment inside each neuron + # condition to adjust neuron in KFN plane in addition to weight adjustment inside each neuron if kfn.learningStage == "start_learning" if kfn.recent_knowledgeFn_error === nothing && kfn.knowledgeFn_error === nothing kfn.recent_knowledgeFn_error = [[]] @@ -279,13 +249,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String) if updown == "up" if currentStrength > 4 # strong connection - updatedStrength = currentStrength + (Δstrength * 1.0) + updatedStrength = currentStrength + (Δstrength * 0.2) else updatedStrength = currentStrength + (Δstrength * 0.1) end elseif updown == "down" if currentStrength > 4 - updatedStrength = currentStrength - (Δstrength * 0.5) + updatedStrength = currentStrength - (Δstrength * 0.1) else updatedStrength = currentStrength - (Δstrength * 0.2) end @@ -358,81 +328,80 @@ function neuroplasticity!(n::computeNeuron, firedNeurons::Vector, nExInTypeList::Vector) # if there is 0-weight then replace it with new connection zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight + if length(zeroWeightConnIndex) != 0 + # new synaptic connection must sample fron neuron that fires + nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list + filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list - # new synaptic connection must sample fron neuron that fires - nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list - filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list + nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool) - nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool) - filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list - filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list + filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list + filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list - w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex)) - synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex)) + w = randn(length(zeroWeightConnIndex)) / 100 + synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex)) - shuffle!(nFiredPool) - shuffle!(nNonFiredPool) + shuffle!(nFiredPool) + shuffle!(nNonFiredPool) - # add new synaptic connection to neuron - for (i, connIndex) in enumerate(zeroWeightConnIndex) - if length(nFiredPool) != 0 - newConn = popfirst!(nFiredPool) - else - newConn = popfirst!(nNonFiredPool) - end - - """ conn that is being replaced has to go into nNonFiredPool so nNonFiredPool isn't empty - """ - push!(nNonFiredPool, n.subscriptionList[connIndex]) - n.subscriptionList[connIndex] = newConn - n.wRec[connIndex] = w[i] * nExInTypeList[newConn] - n.synapticStrength[connIndex] = synapticStrength[i] - end -end - -function neuroplasticity!(n::outputNeuron, firedNeurons::Vector, - nExInTypeList::Vector, totalInputNeuron::Integer) - # if there is 0-weight then replace it with new connection - zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight - - # new synaptic connection must sample fron neuron that fires - nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list - filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list - filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron - - nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool) - filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list - filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list - filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron - - w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex)) - synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex)) - - shuffle!(nFiredPool) - shuffle!(nNonFiredPool) - - # add new synaptic connection to neuron - for (i, connIndex) in enumerate(zeroWeightConnIndex) - newConn::Int64 = 0 - if length(nFiredPool) != 0 - newConn = popfirst!(nFiredPool) - elseif length(nNonFiredPool) != 0 - newConn = popfirst!(nNonFiredPool) - else - # skip - end - - if newConn != 0 - """ conn that is being replaced has to go into nNonFiredPool so nNonFiredPool isn't empty - """ + # add new synaptic connection to neuron + for (i, connIndex) in enumerate(zeroWeightConnIndex) + """ conn that is being replaced has to go into nNonFiredPool so + nNonFiredPool isn't empty """ push!(nNonFiredPool, n.subscriptionList[connIndex]) + + if length(nFiredPool) != 0 + newConn = popfirst!(nFiredPool) + else + newConn = popfirst!(nNonFiredPool) + end n.subscriptionList[connIndex] = newConn - n.wRec[connIndex] = w[i] * nExInTypeList[newConn] + n.wRec[connIndex] = abs(w[i]) * nExInTypeList[newConn] n.synapticStrength[connIndex] = synapticStrength[i] end end end +# function neuroplasticity!(n::outputNeuron, firedNeurons::Vector, +# nExInTypeList::Vector, totalInputNeuron::Integer) +# # if there is 0-weight then replace it with new connection +# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight +# if length(zeroWeightConnIndex) != 0 +# # new synaptic connection must sample fron neuron that fires +# nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list +# filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list +# filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron + +# nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool) +# unique!(append!(nNonFiredPool, zeroWeightConnIndex)) +# filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list +# filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list +# filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron + +# w = randn(length(zeroWeightConnIndex)) / 100 +# synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex)) + +# shuffle!(nFiredPool) +# shuffle!(nNonFiredPool) + +# # add new synaptic connection to neuron +# for (i, connIndex) in enumerate(zeroWeightConnIndex) +# """ conn that is being replaced has to go into nNonFiredPool so +# nNonFiredPool isn't empty """ +# push!(nNonFiredPool, n.subscriptionList[connIndex]) + +# if length(nFiredPool) != 0 +# newConn = popfirst!(nFiredPool) +# else +# newConn = popfirst!(nNonFiredPool) +# end +# n.subscriptionList[connIndex] = newConn +# n.wRec[connIndex] = w[i] * nExInTypeList[newConn] +# n.synapticStrength[connIndex] = synapticStrength[i] +# end +# end +# end + """ Cap maximum weight of each neuron connection """ function capMaxWeight!(v::Vector{Float64}, max=1.0) diff --git a/src/types.jl b/src/types.jl index 7aa9dfc..51a25cc 100644 --- a/src/types.jl +++ b/src/types.jl @@ -4,6 +4,7 @@ export # struct IronpenStruct, model, knowledgeFn, lifNeuron, alifNeuron, linearNeuron, kfn_1, inputNeuron, computeNeuron, neuron, outputNeuron, passthroughNeuron, + integrateNeuron, # function instantiate_custom_types, init_neuron, populate_neuron, @@ -22,6 +23,8 @@ abstract type computeNeuron <: neuron end #------------------------------------------------------------------------------------------------100 +rng = MersenneTwister(1234) + """ Model struct """ Base.@kwdef mutable struct model <: Ironpen @@ -262,16 +265,16 @@ function kfn_1(kfnParams::Dict) end end - # add ExInType into each output neuron subExInType - for n in kfn.outputNeuronsArray - try # input neuron doest have n.subscriptionList - for (i, sub_id) in enumerate(n.subscriptionList) - n_ExInType = kfn.neuronsArray[sub_id].ExInType - n.wRec[i] *= n_ExInType - end - catch - end - end + # # add ExInType into each output neuron subExInType + # for n in kfn.outputNeuronsArray + # try # input neuron doest have n.subscriptionList + # for (i, sub_id) in enumerate(n.subscriptionList) + # n_ExInType = kfn.neuronsArray[sub_id].ExInType + # n.wRec[i] *= n_ExInType + # end + # catch + # end + # end for n in kfn.neuronsArray push!(kfn.nExInType, n.ExInType) @@ -339,6 +342,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper alpha::Float64 = 0.0 # α, neuron membrane potential decay factor + alphaChange::Float64 = 0.0 phi::Float64 = 0.0 # ϕ, psuedo derivative epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec @@ -347,7 +351,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond refractoryCounter::Int64 = 0 tau_m::Float64 = 100.0 # τ_m, membrane time constant in millisecond - eta::Float64 = 0.01 # η, learning rate + eta::Float64 = 1e-3 # η, learning rate wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change recSignal::Float64 = 0.0 # incoming recurrent signal alpha_v_t::Float64 = 0.0 # alpha * v_t @@ -428,6 +432,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>0), upperlimit=(5=>5)) alpha::Float64 = 0.0 # α, neuron membrane potential decay factor + alphaChange::Float64 = 0.0 delta::Float64 = 1.0 # δ, discreate timestep size in millisecond epsilonRec::Array{Float64} = Float64[] # ϵ_rec(v), eligibility vector for neuron i spike epsilonRecA::Array{Float64} = Float64[] # ϵ_rec(a) @@ -435,7 +440,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th eRec::Array{Float64} = Float64[] # neuron's eligibility trace - eta::Float64 = 0.01 # eta, learning rate + eta::Float64 = 1e-3 # eta, learning rate gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper phi::Float64 = 0.0 # ϕ, psuedo derivative refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond @@ -510,7 +515,7 @@ end """ linearNeuron struct """ Base.@kwdef mutable struct linearNeuron <: outputNeuron - id::Float64 = 0.0 # ID of this neuron which is it position in knowledgeFn array + id::Int64 = 0 # ID of this neuron which is it position in knowledgeFn array type::String = "linearNeuron" knowledgeFnName::String = "not defined" # knowledgeFn that this neuron belongs to subscriptionList::Array{Int64} = Int64[] # list of other neuron that this neuron synapse subscribed to @@ -545,7 +550,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond refractoryCounter::Int64 = 0 tau_out::Float64 = 50.0 # τ_out, membrane time constant in millisecond - eta::Float64 = 0.01 # η, learning rate + eta::Float64 = 1e-3 # η, learning rate wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change recSignal::Float64 = 0.0 # incoming recurrent signal alpha_v_t::Float64 = 0.0 # alpha * v_t @@ -584,6 +589,87 @@ function linearNeuron(params::Dict) return n end +#------------------------------------------------------------------------------------------------100 +""" integrateNeuron struct +""" +Base.@kwdef mutable struct integrateNeuron <: outputNeuron + id::Int64 = 0 # ID of this neuron which is it position in knowledgeFn array + type::String = "integrateNeuron" + knowledgeFnName::String = "not defined" # knowledgeFn that this neuron belongs to + subscriptionList::Array{Int64} = Int64[] # list of other neuron that this neuron synapse subscribed to + timeStep::Int64 = 0 # current time + wRec::Array{Float64} = Float64[] # synaptic weight (for receiving signal from other neuron) + v_t::Float64 = randn() # vᵗ, postsynaptic neuron membrane potential of previous timestep + v_t1::Float64 = 0.0 # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep + v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold + vRest::Float64 = 0.0 # resting potential after neuron fired + vError::Float64 = 0.0 # used to compute model error + z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep + # zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all + # neurons forward function at each timestep-by-timestep is to do every neuron + # forward calculation. Each neuron requires access to other neuron's firing status + # during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t + z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation) + b::Float64 = 0.0 + bChange::Float64 = 0.0 + + # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of + # previous timestep) + z_i_t::Array{Bool} = Bool[] + z_i_t_commulative::Array{Int64} = Int64[] # used to compute connection strength + synapticStrength::Array{Float64} = Float64[] + synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>-5), upperlimit=(5=>5)) + + gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper + alpha::Float64 = 0.0 # α, neuron membrane potential decay factor + alphaChange::Float64 = 0.0 + phi::Float64 = 0.0 # ϕ, psuedo derivative + epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike + decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec + eRec::Array{Float64} = Float64[] # eligibility trace for neuron spike + delta::Float64 = 1.0 # δ, discreate timestep size in millisecond + refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond + refractoryCounter::Int64 = 0 + tau_out::Float64 = 50.0 # τ_out, membrane time constant in millisecond + eta::Float64 = 1e-3 # η, learning rate + wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change + recSignal::Float64 = 0.0 # incoming recurrent signal + alpha_v_t::Float64 = 0.0 # alpha * v_t + + firingCounter::Int64 = 0 # store how many times neuron fires + ExInSignalSum::Float64 = 0.0 +end + +""" linear neuron outer constructor + + # Example + + linear_neuron_params = Dict( + :type => "linearNeuron", + :k => 0.9, # output leakink coefficient + :tau_out => 5.0, # output time constant in millisecond. It should equals to time use for 1 sequence + :out => 0.0, # neuron's output value store here + ) + + neuron1 = linearNeuron(linear_neuron_params) +""" +function integrateNeuron(params::Dict) + n = integrateNeuron() + field_names = fieldnames(typeof(n)) + for i in field_names + if i in keys(params) + if i == :optimiser + opt_type = string(split(params[i], ".")[end]) + n.:($i) = load_optimiser(opt_type) + else + n.:($i) = params[i] # assign params to n struct fields + end + end + end + + return n +end + #------------------------------------------------------------------------------------------------100 # function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing) @@ -634,10 +720,11 @@ function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict) # prevent subscription to itself by removing this neuron id filter!(x -> x != n.id, n.subscriptionList) - n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList)) + n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = randn(length(n.subscriptionList)) + # n.wRec = randn(length(n.subscriptionList)) + n.wRec = randn(rng, length(n.subscriptionList)) / 100 n.wRecChange = zeros(length(n.subscriptionList)) n.alpha = calculate_α(n) n.z_i_t_commulative = zeros(length(n.subscriptionList)) @@ -654,10 +741,10 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict, # prevent subscription to itself by removing this neuron id filter!(x -> x != n.id, n.subscriptionList) - n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList)) + n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = randn(length(n.subscriptionList)) + n.wRec = randn(rng, length(n.subscriptionList)) / 100 # TODO use abs() n.wRecChange = zeros(length(n.subscriptionList)) # the more time has passed from the last time neuron was activated, the more @@ -668,8 +755,7 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict, n.z_i_t_commulative = zeros(length(n.subscriptionList)) end - -function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dict) +function init_neuron!(id::Int64, n::integrateNeuron, n_params::Dict, kfnParams::Dict) n.id = id n.knowledgeFnName = kfnParams[:knowledgeFnName] @@ -677,15 +763,33 @@ function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dic subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) * kfnParams[:totalNeurons] - kfnParams[:totalInputPort])) n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers] - n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList)) + n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = randn(length(n.subscriptionList)) + n.wRec = randn(rng, length(n.subscriptionList)) / 100 n.wRecChange = zeros(length(n.subscriptionList)) n.alpha = calculate_k(n) n.z_i_t_commulative = zeros(length(n.subscriptionList)) + n.b = randn(rng) / 100 end +# function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dict) +# n.id = id +# n.knowledgeFnName = kfnParams[:knowledgeFnName] + +# subscription_options = shuffle!([kfnParams[:totalInputPort]+1 : kfnParams[:totalNeurons]...]) +# subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) * +# kfnParams[:totalNeurons] - kfnParams[:totalInputPort])) +# n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers] +# n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList)) + +# n.epsilonRec = zeros(length(n.subscriptionList)) +# n.wRec = randn(rng, length(n.subscriptionList)) / 100 +# n.wRecChange = zeros(length(n.subscriptionList)) +# n.alpha = calculate_k(n) +# n.z_i_t_commulative = zeros(length(n.subscriptionList)) +# end + """ Make a neuron intended for use with knowledgeFn """ function init_neuron(id::Int64, n_params::Dict, kfnParams::Dict) @@ -715,7 +819,9 @@ function instantiate_custom_types(params::Union{Dict,Nothing} = nothing) elseif type == "alifNeuron" return alifNeuron(params) elseif type == "linearNeuron" - return linearNeuron(params) + return linearNeuron(params) + elseif type == "integrateNeuron" + return integrateNeuron(params) else return nothing end @@ -740,6 +846,7 @@ calculate_α(neuron::lifNeuron) = exp(-neuron.delta / neuron.tau_m) calculate_α(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_m) calculate_ρ(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_a) calculate_k(neuron::linearNeuron) = exp(-neuron.delta / neuron.tau_out) +calculate_k(neuron::integrateNeuron) = exp(-neuron.delta / neuron.tau_out) #------------------------------------------------------------------------------------------------100