diff --git a/src/Ironpen.jl b/src/Ironpen.jl index 19638a8..f386398 100644 --- a/src/Ironpen.jl +++ b/src/Ironpen.jl @@ -32,8 +32,9 @@ using .learn # using .interface #------------------------------------------------------------------------------------------------100 -""" version 0.0.4 +""" version 0.0.5 Todo: + - [4] implement dormant connection [] using RL to control learning signal [] consider using Dates.now() instead of timestamp because time_stamp may overflow @@ -41,8 +42,9 @@ using .learn which defined by neuron.tau_m formula in type.jl - Change from version: 0.0.3 - + Change from version: 0.0.4 + - compute error in main loop so one could decide how to calculate error + - compute model error in main loop so one could decide when to calculate error All features diff --git a/src/forward.jl b/src/forward.jl index f23a386..05dab9c 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -90,10 +90,13 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector) end out = [n.z_t1 for n in kfn.outputNeuronsArray] - outputNeuron_v_t1 = [n.v_t1 for n in kfn.outputNeuronsArray] - return out::Array{Bool}, outputNeuron_v_t1::Array{Float64}, sum(kfn.firedNeurons_t1), - [i.ExInSignalSum for i in kfn.outputNeuronsArray] + return out::Array{Bool}, + sum(kfn.firedNeurons_t1), + [n.v_t1 for n in kfn.outputNeuronsArray], + [sum(i.wRec) for i in kfn.outputNeuronsArray], + [sum(i.epsilonRec) for i in kfn.outputNeuronsArray], + [i.phi for i in kfn.outputNeuronsArray] end #------------------------------------------------------------------------------------------------100 @@ -218,16 +221,6 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn n.vError = n.v_t1 # store voltage that will be used to calculate error later else recSignal = n.wRec .* n.z_i_t - n.ExInSignalSum = 0.0 #CHANGE for ploting - for i in recSignal - # if i > 0 - # kfn.exSignalSum += i - # elseif i < 0 - # kfn.inSignalsum += i - # else - # end - n.ExInSignalSum += i - end n.recSignal = sum(recSignal) # signal from other neuron that this neuron subscribed n.alpha_v_t = n.alpha * n.v_t n.v_t1 = n.alpha_v_t + n.recSignal diff --git a/src/learn.jl b/src/learn.jl index 6598236..4c9ed5c 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -19,31 +19,51 @@ function compute_wRecChange!(m::model, error::Vector{Float64}, correctAnswer::Ab compute_wRecChange!(m.knowledgeFn[:I], error, correctAnswer) end +# function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector) +# for (i, error) in enumerate(errors) +# if error == 0 # output is correct +# # Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error +# # # for n in kfn.neuronsArray +# # synapticConnStrength!(n, true) +# # end +# # synapticConnStrength!(kfn.outputNeuronsArray[i], true) +# else # output is wrong, error occurs +# if correctAnswer[i] == 1 # high priority answer +# error = error * +# abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) +# else # low priority answer +# error = error * +# abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) +# end + +# Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error +# # for n in kfn.neuronsArray +# compute_wRecChange!(n, error) +# # synapticConnStrength!(n, false) +# end +# compute_wRecChange!(kfn.outputNeuronsArray[i], error) +# # synapticConnStrength!(kfn.outputNeuronsArray[i], false) +# end +# end + + function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector) for (i, error) in enumerate(errors) - if error == 0 # output is correct - # Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error - # # for n in kfn.neuronsArray - # synapticConnStrength!(n, true) - # end - # synapticConnStrength!(kfn.outputNeuronsArray[i], true) - else # output is wrong, error occurs - if correctAnswer[i] == 1 # high priority answer - error = error * + if error < 0 # model fires too fast + error = error * abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) - else # low priority answer - error = error * + elseif error == 0 # model answer correctly. maintain membrain potential ≈ 0.5 + error = error * + abs(kfn.outputNeuronsArray[i].v_th/2 - kfn.outputNeuronsArray[i].vError) + else # model fires too slow + error = error * abs(kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) - end - - Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error - # for n in kfn.neuronsArray - compute_wRecChange!(n, error) - # synapticConnStrength!(n, false) - end - compute_wRecChange!(kfn.outputNeuronsArray[i], error) - # synapticConnStrength!(kfn.outputNeuronsArray[i], false) end + Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error + # for n in kfn.neuronsArray + compute_wRecChange!(n, error) + end + compute_wRecChange!(kfn.outputNeuronsArray[i], error) end end diff --git a/src/types.jl b/src/types.jl index 729eef9..7aa9dfc 100644 --- a/src/types.jl +++ b/src/types.jl @@ -346,7 +346,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron delta::Float64 = 1.0 # δ, discreate timestep size in millisecond refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond refractoryCounter::Int64 = 0 - tau_m::Float64 = 50.0 # τ_m, membrane time constant in millisecond + tau_m::Float64 = 100.0 # τ_m, membrane time constant in millisecond eta::Float64 = 0.01 # η, learning rate wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change recSignal::Float64 = 0.0 # incoming recurrent signal @@ -440,7 +440,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron phi::Float64 = 0.0 # ϕ, psuedo derivative refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond refractoryCounter::Int64 = 0 - tau_m::Float64 = 50.0 # τ_m, membrane time constant in millisecond + tau_m::Float64 = 100.0 # τ_m, membrane time constant in millisecond wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change recSignal::Float64 = 0.0 # incoming recurrent signal alpha_v_t::Float64 = 0.0 # alpha * v_t @@ -453,7 +453,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization firingRate::Float64 = 0.0 # running average of firing rate, Hz - tau_a::Float64 = 50.0 # τ_a, adaption time constant in millisecond + tau_a::Float64 = 100.0 # τ_a, adaption time constant in millisecond beta::Float64 = 0.15 # β, constant, value from paper rho::Float64 = 0.0 # ρ, threshold adaptation decay factor a::Float64 = 0.0 # threshold adaptation @@ -544,7 +544,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron delta::Float64 = 1.0 # δ, discreate timestep size in millisecond refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond refractoryCounter::Int64 = 0 - tau_out::Float64 = 25.0 # τ_out, membrane time constant in millisecond + tau_out::Float64 = 50.0 # τ_out, membrane time constant in millisecond eta::Float64 = 0.01 # η, learning rate wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change recSignal::Float64 = 0.0 # incoming recurrent signal @@ -637,7 +637,7 @@ function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict) n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList)) + n.wRec = randn(length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList)) n.alpha = calculate_α(n) n.z_i_t_commulative = zeros(length(n.subscriptionList)) @@ -657,7 +657,7 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict, n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList)) + n.wRec = randn(length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList)) # the more time has passed from the last time neuron was activated, the more @@ -680,7 +680,7 @@ function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dic n.synapticStrength = rand(-5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList)) + n.wRec = randn(length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList)) n.alpha = calculate_k(n) n.z_i_t_commulative = zeros(length(n.subscriptionList))