time-based learning method based on new error formula

This commit is contained in:
2023-05-16 20:56:05 +07:00
parent 114161ba69
commit 70d2521c5e
5 changed files with 146 additions and 227 deletions

View File

@@ -34,14 +34,16 @@ using .interface
"""
Todo:
[7] time-based learning method based on new error formula
[*6] time-based learning method based on new error formula
(use output vt compared to vth instead of late time)
if the output neuron does not activate when it should, use the output neuron's
(vth - vt)*100/vth as error
if output neuron activates when it should NOT, use output neuron's
(vt*100)/vth as error
[8] verify that model can complete learning cycle with no error
[5] synaptic connection strength concept. use sigmoid
[6] neuroplasticity() i.e. change connection
[7] use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge
[9] verify that model can complete learning cycle with no error
[*5] synaptic connection strength concept. use sigmoid, turn connection offline
[8] neuroplasticity() i.e. change connection
[] using RL to control learning signal
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
[] training should include adjusting α, neuron membrane potential decay factor

View File

@@ -11,7 +11,6 @@ using ..types, ..snn_utils
""" Model forward()
"""
function (m::model)(input_data::AbstractVector)
# m.global_tick += 1
m.timeStep += 1
# process all corresponding KFN
@@ -31,7 +30,6 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
kfn.timeStep = m.timeStep
kfn.softreset = m.softreset
kfn.learningStage = m.learningStage
kfn.error = m.error
# generate noise
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
@@ -101,8 +99,8 @@ function (n::lif_neuron)(kfn::knowledgeFn)
# last only 1 timestep follow by a period of refractory.
n.recSignal = n.recSignal * 0.0
# Exponential decay of v_t1
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
# decay of v_t1
n.v_t1 = n.alpha * n.v_t
else
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
@@ -142,8 +140,8 @@ function (n::alif_neuron)(kfn::knowledgeFn)
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
n.recSignal = n.recSignal * 0.0
# Exponential decay of v_t1
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
# decay of v_t1
n.v_t1 = n.alpha * n.v_t
n.phi = 0
else
n.z_t = isnothing(n.z_t) ? false : n.z_t
@@ -187,8 +185,8 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
# last only 1 timestep follow by a period of refractory.
n.recSignal = n.recSignal * 0.0
# Exponential decay of v_t1
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
# decay of v_t1
n.v_t1 = n.alpha * n.v_t
else
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed

View File

@@ -4,7 +4,7 @@ using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types
using ..types, ..snn_utils
export learn!
@@ -12,6 +12,23 @@ export learn!
""" Model learn!()

Propagate the model's learning stage to knowledge function `:I`, build the
per-KFN target vector, and delegate to the KFN-level `learn!`.
When `correctAnswer` is `nothing` (no supervision available for this step),
a zero vector of `length(modelRespond)` is used as the target, i.e. "no
output neuron should be firing".
"""
function learn!(m::model, modelRespond, correctAnswer=nothing)
m.knowledgeFn[:I].learningStage = m.learningStage
# # how many matched respond and correct answer
# matched = sum(isequal(modelRespond, correctAnswer))
if correctAnswer === nothing
# no supervision: target is all-zeros with the same length as the model's respond
correctAnswer_I = zeros(length(modelRespond))
else
correctAnswer_I = correctAnswer # correct answer for kfn I
end
learn!(m.knowledgeFn[:I], correctAnswer_I)
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
# ΔWeight Conn. Strength
# case 1 no no during input signal, no correct answer available, no answer
# case 2 no - during input signal, no correct answer available, wrong answer
@@ -27,38 +44,57 @@ function learn!(m::model, modelRespond, correctAnswer=nothing)
# success
# how many matched respond and correct answer
matched = sum(isequal(modelRespond, correctAnswer))
correctAnswer_I = correctAnswer # correct answer for kfn I
learn!(m.knowledgeFn[:I], correctAnswer_I)
# return model_error
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer=nothing)
if kfn.learningStage == "start_learning"
# reset params here instead of at the end_learning so that neuron's parameter data
# don't gets wiped and can be logged for visualization later
for n in kfn.neuronsArray
# epsilonRec need to be reset because it counting how many each synaptic fires and
# use this info to calculate how much synaptic weight should be adjust
reset_learning_params!(n)
resetLearningParams!(n)
end
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.outputs = nothing
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "learning"
#TODO prepare for end learning
elseif kfn.learningStage == "end_learning"
reset_learning_params!(n)
resetLearningParams!(n)
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "inference"
end
# compute kfn error
out = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, v) in enumerate(out)
if v != correctAnswer[i] # need to adjust weight
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
100 / kfn.outputNeuronsArray[i].v_th
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
learn!(n, kfnError)
end
learn!(kfn.outputNeuronsArray[i], kfn)
end
end
#WORKING
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
learn!(n, kfn) # Neurons are always learning, besides error from model output
@@ -71,7 +107,7 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
# other output neurons
learn!(n, kfn)
end
#TODO: put other KFN to learn here
# for main loop user's display and training's exit condition
avgNeuronsFiringRate = 0.0
@@ -90,6 +126,25 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
end
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
end
# wrap up learning session
if kfn.learningStage == "end_learning"
#TODO neuroplasticity
resetLearningParams!(n)
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "inference"
end
end
""" passthrough_neuron learn()
@@ -100,71 +155,23 @@ end
""" lif learn()
"""
function learn!(n::lif_neuron, kfn::knowledgeFn)
if n.learnable_flag == true
function learn!(n::lif_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
end
ΔwRecChange = n.eta * error
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
# a piece of knowledgeFn error that belongs to this neuron
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
n.learningStage = kfn.learningStage
# accumulate voltage regularization terms
Snn_utils.cal_v_reg!(n)
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning" ||
n.learningStage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
end
# check for flipped sign, 1 indicates non-flipped sign
wSign = sign.(n.wRecChange)
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end
""" alif_neuron learn()
"""
function learn!(n::alif_neuron, kfn::knowledgeFn)
function learn!(n::alif_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.epsilonRecA = (n.phi * n.epsilonRec) +
@@ -173,117 +180,29 @@ function learn!(n::alif_neuron, kfn::knowledgeFn)
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
n.eRec = n.eRec_v + n.eRec_a
# a piece of knowledgeFn error that belongs to this neuron
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
n.learningStage = kfn.learningStage
ΔwRecChange = n.eta * error
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning" ||
n.learningStage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
end
# check for flipped sign, 1 indicates non-flipped sign
wSign = sign.(n.wRecChange)
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end
""" linear_neuron learn()
"""
function learn!(n::linear_neuron, kfn::knowledgeFn)
n.error = kfn.outputError[n.id]
n.learningStage = kfn.learningStage
function learn!(n::linear_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
ΔwRecChange = n.eta * error
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
n.w_out = n.w_out + n.w_out_change
n.b = n.b + n.b_change
end
# check for flipped sign, 1 indicates non-flipped sign
wSign = sign.(n.wRecChange)
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end

View File

@@ -3,13 +3,13 @@ module snn_utils
using Flux.Optimise: apply!
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
reset_z_t!, reset_learning_params!, reset_learning_history_params!,
reset_z_t!, resetLearningParams!, reset_learning_history_params!,
cal_v_reg!, calculate_w_change_end!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss
using Statistics, Random, LinearAlgebra, Distributions, Zygote
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
using ..types
@@ -98,21 +98,19 @@ reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0
""" Reset all learning-related params at the END of learning session
"""
function reset_learning_params!(n::lif_neuron)
# Reset all learning-related state of a `lif_neuron` so a new learning session
# starts clean: eligibility traces, accumulated weight changes, firing counters,
# error bookkeeping, and the refractory counter.
function resetLearningParams!(n::lif_neuron)
reset_epsilon_rec!(n)
reset_w_rec_change!(n)
# reset_v_t!(n)
# reset_z_t!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
# NOTE(review): the alif variant's hunk shows the two error resets being removed
# in this commit — confirm whether they are still intended here for lif_neuron
reset_previous_error!(n)
reset_error!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever
reset_refractoryCounter!(n)
end
function reset_learning_params!(n::alif_neuron)
function resetLearningParams!(n::alif_neuron)
reset_epsilon_rec!(n)
reset_epsilon_rec_a!(n)
reset_w_rec_change!(n)
@@ -121,8 +119,6 @@ function reset_learning_params!(n::alif_neuron)
# reset_a!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever
@@ -132,18 +128,15 @@ end
# function reset_learning_no_wchange!(n::passthrough_neuron)
# end
function reset_learning_params!(n::passthrough_neuron)
# Passthrough neurons carry no learnable state, so there is nothing to reset.
# This no-op method exists so callers can reset every neuron uniformly.
function resetLearningParams!(n::passthrough_neuron)
# skip
end
#WORKING
function reset_learning_params!(n::linear_neuron)
function resetLearningParams!(n::linear_neuron)
reset_epsilon_rec!(n)
reset_w_rec_change!(n)
reset_v_t!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever
@@ -288,14 +281,19 @@ function push_epsilon_rec_a!(n::alif_neuron)
push!(n.epsilonRecA, 0)
end
""" compute synaptic connection strength. bias will shift currentStrength to fit into
sigmoid operating range which centred at 0 and range is -37 to 37.
# Example
synaptic strength range is 0 to 10
one may use bias = -5 to transform synaptic strength into range -5 to 5
the return value is shifted back to original scale
"""
function synapticConnStrength(currentStrength::AbstractFloat, bias::Number=0)
# shift into the sigmoid's operating range (centred at 0)
currentStrength += bias
# NOTE(review): the value of this expression is discarded — as written this line
# has no effect on currentStrength. It looks like it was meant to update the
# strength (e.g. `currentStrength -= (1.0 - sigmoid(currentStrength))`), or it is
# a leftover removed line from the diff this page renders — confirm the intended
# formula before relying on this function.
currentStrength - (1.0 - sigmoid(currentStrength))
# shift back to the original scale before returning
currentStrength -= bias
return currentStrength
end

View File

@@ -325,7 +325,8 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out
synapticStrength::Union{Array{Float64},Nothing} = nothing
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(0=>0), upperlimit=(10=>10))
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
@@ -334,7 +335,6 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
refractoryCounter::Integer = 0
@@ -418,7 +418,8 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out
synapticStrength::Union{Array{Float64},Nothing} = nothing
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
@@ -430,7 +431,6 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
@@ -528,6 +528,8 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
# previous timestep)
z_i_t::Union{Array{Bool},Nothing} = nothing
synapticStrength::Union{Array{Float64},Nothing} = nothing
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
@@ -536,7 +538,6 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
refractoryCounter::Integer = 0
tau_out::Union{Float64,Nothing} = nothing # τ_out, membrane time constant in millisecond
@@ -629,11 +630,11 @@ function init_neuron!(id::Int64, n::lif_neuron, n_params::Dict, kfnParams::Dict)
# prevent subscription to itself by removing this neuron id
filter!(x -> x != n.id, n.subscriptionList)
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
n.epsilonRec = zeros(length(n.subscriptionList))
n.w_rec = Random.rand(length(n.subscriptionList))
n.wRecChange = zeros(length(n.subscriptionList))
# n.reg_voltage_b = zeros(length(n.subscriptionList))
n.alpha = calculate_α(n)
end
@@ -648,6 +649,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
# prevent subscription to itself by removing this neuron id
filter!(x -> x != n.id, n.subscriptionList)
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
n.epsilonRec = zeros(length(n.subscriptionList))
n.w_rec = Random.rand(length(n.subscriptionList))
@@ -660,7 +662,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
n.epsilonRecA = zeros(length(n.subscriptionList))
end
#WORKING
function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Dict)
n.id = id
n.knowledgeFnName = kfnParams[:knowledgeFnName]
@@ -669,7 +671,7 @@ function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Di
subscription_numbers = Int(floor(n_params[:synaptic_connection_number] *
kfnParams[:total_compute_neuron] / 100.0))
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
n.epsilonRec = zeros(length(n.subscriptionList))
n.w_rec = Random.rand(length(n.subscriptionList))
n.wRecChange = zeros(length(n.subscriptionList))