time-based learning method based on new error formula

This commit is contained in:
2023-05-16 20:56:05 +07:00
parent 114161ba69
commit 70d2521c5e
5 changed files with 146 additions and 227 deletions

View File

@@ -34,14 +34,16 @@ using .interface
""" """
Todo: Todo:
[7] time-based learning method based on new error formula [*6] time-based learning method based on new error formula
(use output vt compared to vth instead of late time)
if output neuron not activate when it should, use output neuron's if output neuron not activate when it should, use output neuron's
(vth - vt)*100/vth as error (vth - vt)*100/vth as error
if output neuron activates when it should NOT, use output neuron's if output neuron activates when it should NOT, use output neuron's
(vt*100)/vth as error (vt*100)/vth as error
[8] verify that model can complete learning cycle with no error [7] use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge
[5] synaptic connection strength concept. use sigmoid [9] verify that model can complete learning cycle with no error
[6] neuroplasticity() i.e. change connection [*5] synaptic connection strength concept. use sigmoid, turn connection offline
[8] neuroplasticity() i.e. change connection
[] using RL to control learning signal [] using RL to control learning signal
[] consider using Dates.now() instead of timestamp because time_stamp may overflow [] consider using Dates.now() instead of timestamp because time_stamp may overflow
[] training should include adjusting α, neuron membrane potential decay factor [] training should include adjusting α, neuron membrane potential decay factor

View File

@@ -11,7 +11,6 @@ using ..types, ..snn_utils
""" Model forward() """ Model forward()
""" """
function (m::model)(input_data::AbstractVector) function (m::model)(input_data::AbstractVector)
# m.global_tick += 1
m.timeStep += 1 m.timeStep += 1
# process all corresponding KFN # process all corresponding KFN
@@ -31,7 +30,6 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
kfn.timeStep = m.timeStep kfn.timeStep = m.timeStep
kfn.softreset = m.softreset kfn.softreset = m.softreset
kfn.learningStage = m.learningStage kfn.learningStage = m.learningStage
kfn.error = m.error
# generate noise # generate noise
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5]) noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
@@ -101,8 +99,8 @@ function (n::lif_neuron)(kfn::knowledgeFn)
# last only 1 timestep follow by a period of refractory. # last only 1 timestep follow by a period of refractory.
n.recSignal = n.recSignal * 0.0 n.recSignal = n.recSignal * 0.0
# Exponantial decay of v_t1 # decay of v_t1
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t n.v_t1 = n.alpha * n.v_t
else else
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
@@ -142,8 +140,8 @@ function (n::alif_neuron)(kfn::knowledgeFn)
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t) n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
n.recSignal = n.recSignal * 0.0 n.recSignal = n.recSignal * 0.0
# Exponantial decay of v_t1 # decay of v_t1
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t n.v_t1 = n.alpha * n.v_t
n.phi = 0 n.phi = 0
else else
n.z_t = isnothing(n.z_t) ? false : n.z_t n.z_t = isnothing(n.z_t) ? false : n.z_t
@@ -187,8 +185,8 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
# last only 1 timestep follow by a period of refractory. # last only 1 timestep follow by a period of refractory.
n.recSignal = n.recSignal * 0.0 n.recSignal = n.recSignal * 0.0
# Exponantial decay of v_t1 # decay of v_t1
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t n.v_t1 = n.alpha * n.v_t
else else
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed

View File

@@ -4,7 +4,7 @@ using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils using GeneralUtils
using ..types using ..types, ..snn_utils
export learn! export learn!
@@ -12,6 +12,23 @@ export learn!
function learn!(m::model, modelRespond, correctAnswer=nothing) function learn!(m::model, modelRespond, correctAnswer=nothing)
m.knowledgeFn[:I].learningStage = m.learningStage m.knowledgeFn[:I].learningStage = m.learningStage
# # how many matched respond and correct answer
# matched = sum(isequal(modelRespond, correctAnswer))
if correctAnswer === nothing
correctAnswer_I = zeros(length(modelRespond))
else
correctAnswer_I = correctAnswer # correct answer for kfn I
end
learn!(m.knowledgeFn[:I], correctAnswer_I)
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
# ΔWeight Conn. Strength # ΔWeight Conn. Strength
# case 1 no no during input signal, no correct answer available, no answer # case 1 no no during input signal, no correct answer available, no answer
# case 2 no - during input signal, no correct answer available, wrong answer # case 2 no - during input signal, no correct answer available, wrong answer
@@ -27,38 +44,57 @@ function learn!(m::model, modelRespond, correctAnswer=nothing)
# success # success
# how many matched respond and correct answer
matched = sum(isequal(modelRespond, correctAnswer))
correctAnswer_I = correctAnswer # correct answer for kfn I
learn!(m.knowledgeFn[:I], correctAnswer_I)
# return model_error
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer=nothing)
if kfn.learningStage == "start_learning" if kfn.learningStage == "start_learning"
# reset params here instead of at the end_learning so that neuron's parameter data # reset params here instead of at the end_learning so that neuron's parameter data
# don't gets wiped and can be logged for visualization later # don't gets wiped and can be logged for visualization later
for n in kfn.neuronsArray for n in kfn.neuronsArray
# epsilonRec need to be reset because it counting how many each synaptic fires and # epsilonRec need to be reset because it counting how many each synaptic fires and
# use this info to calculate how much synaptic weight should be adjust # use this info to calculate how much synaptic weight should be adjust
reset_learning_params!(n) resetLearningParams!(n)
end end
# clear variables # clear variables
kfn.firedNeurons = Vector{Int64}() kfn.firedNeurons = Vector{Int64}()
kfn.outputs = nothing kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "learning" kfn.learningStage = "learning"
#TODO prepare for end learning
elseif kfn.learningStage == "end_learning" elseif kfn.learningStage == "end_learning"
reset_learning_params!(n) resetLearningParams!(n)
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "inference" kfn.learningStage = "inference"
end end
# compute kfn error
out = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, v) in enumerate(out)
if v != correctAnswer[i] # need to adjust weight
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
100 / kfn.outputNeuronsArray[i].v_th
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
learn!(n, kfnError)
end
learn!(kfn.outputNeuronsArray[i], kfn)
end
end
#WORKING
# Threads.@threads for n in kfn.neuronsArray # Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray for n in kfn.neuronsArray
learn!(n, kfn) # Neurons are always learning, besides error from model output learn!(n, kfn) # Neurons are always learning, besides error from model output
@@ -71,7 +107,7 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
# other output neurons # other output neurons
learn!(n, kfn) learn!(n, kfn)
end end
#TODO: put other KFN to learn here
# for main loop user's display and training's exit condition # for main loop user's display and training's exit condition
avgNeuronsFiringRate = 0.0 avgNeuronsFiringRate = 0.0
@@ -90,6 +126,25 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
end end
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number] kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
end end
# wrap up learning session
if kfn.learningStage == "end_learning"
#TODO neuroplasticity
resetLearningParams!(n)
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "inference"
end
end end
""" passthrough_neuron learn() """ passthrough_neuron learn()
@@ -100,71 +155,23 @@ end
""" lif learn() """ lif learn()
""" """
function learn!(n::lif_neuron, kfn::knowledgeFn) function learn!(n::lif_neuron, error::Number)
if n.learnable_flag == true n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
n.decayedEpsilonRec = n.alpha * n.epsilonRec ΔwRecChange = n.eta * error
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
n.eRec = n.phi * n.epsilonRec
end
# a piece of knowledgeFn error that belongs to this neuron # check for fliped sign, 1 indicates non-fliped sign
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn wSign = sign.(n.wRecChange)
n.learningStage = kfn.learningStage nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
# accumulate voltage regularization terms
Snn_utils.cal_v_reg!(n)
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning" ||
n.learningStage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
end
end end
""" alif_neuron learn() """ alif_neuron learn()
""" """
function learn!(n::alif_neuron, kfn::knowledgeFn) function learn!(n::alif_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.epsilonRecA = (n.phi * n.epsilonRec) + n.epsilonRecA = (n.phi * n.epsilonRec) +
@@ -173,117 +180,29 @@ function learn!(n::alif_neuron, kfn::knowledgeFn)
n.eRec_a = -n.phi * n.beta * n.epsilonRecA n.eRec_a = -n.phi * n.beta * n.epsilonRecA
n.eRec = n.eRec_v + n.eRec_a n.eRec = n.eRec_v + n.eRec_a
# a piece of knowledgeFn error that belongs to this neuron ΔwRecChange = n.eta * error
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
n.learningStage = kfn.learningStage
# check for fliped sign, 1 indicates non-fliped sign
wSign = sign.(n.wRecChange)
if n.learningStage == "doing_inference" nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
# no learning n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
elseif n.learningStage == "start_learning" ||
n.learningStage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
end
end end
""" linear_neuron learn() """ linear_neuron learn()
""" """
function learn!(n::linear_neuron, kfn::knowledgeFn) function learn!(n::linear_neuron, error::Number)
n.error = kfn.outputError[n.id] n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.learningStage = kfn.learningStage n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
if n.learningStage == "doing_inference" ΔwRecChange = n.eta * error
# no learning n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
elseif n.learningStage == "start_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
n.w_out = n.w_out + n.w_out_change # check for fliped sign, 1 indicates non-fliped sign
n.b = n.b + n.b_change wSign = sign.(n.wRecChange)
end nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end end

View File

@@ -3,13 +3,13 @@ module snn_utils
using Flux.Optimise: apply! using Flux.Optimise: apply!
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!, export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!, precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
reset_z_t!, reset_learning_params!, reset_learning_history_params!, reset_z_t!, resetLearningParams!, reset_learning_history_params!,
cal_v_reg!, calculate_w_change_end!, cal_v_reg!, calculate_w_change_end!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!, firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!, neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss gradient_withloss
using Statistics, Random, LinearAlgebra, Distributions, Zygote using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
using ..types using ..types
@@ -98,21 +98,19 @@ reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0
""" Reset all learning-related params at the END of learning session """ Reset all learning-related params at the END of learning session
""" """
function reset_learning_params!(n::lif_neuron) function resetLearningParams!(n::lif_neuron)
reset_epsilon_rec!(n) reset_epsilon_rec!(n)
reset_w_rec_change!(n) reset_w_rec_change!(n)
# reset_v_t!(n) # reset_v_t!(n)
# reset_z_t!(n) # reset_z_t!(n)
reset_firing_counter!(n) reset_firing_counter!(n)
reset_firing_diff!(n) reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into # reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever # refractory state, it will stay in refractory state forever
reset_refractoryCounter!(n) reset_refractoryCounter!(n)
end end
function reset_learning_params!(n::alif_neuron) function resetLearningParams!(n::alif_neuron)
reset_epsilon_rec!(n) reset_epsilon_rec!(n)
reset_epsilon_rec_a!(n) reset_epsilon_rec_a!(n)
reset_w_rec_change!(n) reset_w_rec_change!(n)
@@ -121,8 +119,6 @@ function reset_learning_params!(n::alif_neuron)
# reset_a!(n) # reset_a!(n)
reset_firing_counter!(n) reset_firing_counter!(n)
reset_firing_diff!(n) reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into # reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever # refractory state, it will stay in refractory state forever
@@ -132,18 +128,15 @@ end
# function reset_learning_no_wchange!(n::passthrough_neuron) # function reset_learning_no_wchange!(n::passthrough_neuron)
# end # end
function reset_learning_params!(n::passthrough_neuron) function resetLearningParams!(n::passthrough_neuron)
# skip # skip
end end
#WORKING
function reset_learning_params!(n::linear_neuron) function resetLearningParams!(n::linear_neuron)
reset_epsilon_rec!(n) reset_epsilon_rec!(n)
reset_w_rec_change!(n) reset_w_rec_change!(n)
reset_v_t!(n) reset_v_t!(n)
reset_firing_counter!(n) reset_firing_counter!(n)
reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into # reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever # refractory state, it will stay in refractory state forever
@@ -288,14 +281,19 @@ function push_epsilon_rec_a!(n::alif_neuron)
push!(n.epsilonRecA, 0) push!(n.epsilonRecA, 0)
end end
""" compute synaptic connection strength. bias will shift currentStrength to fit into
sigmoid operating range which centred at 0 and range is -37 to 37.
# Example
synaptic strength range is 0 to 10
one may use bias = -5 to transform synaptic strength into range -5 to 5
the return value is shifted back to original scale
"""
function synapticConnStrength(currentStrength::AbstractFloat, bias::Number=0)
currentStrength += bias
currentStrength - (1.0 - sigmoid(currentStrength))
currentStrength -= bias
return currentStrength
end

View File

@@ -325,7 +325,8 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t # during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation) z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep) z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out synapticStrength::Union{Array{Float64},Nothing} = nothing
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(0=>0), upperlimit=(10=>10))
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
@@ -334,7 +335,6 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information # refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
refractoryCounter::Integer = 0 refractoryCounter::Integer = 0
@@ -418,7 +418,8 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t # during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation) z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep) z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out synapticStrength::Union{Array{Float64},Nothing} = nothing
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
@@ -430,7 +431,6 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information # refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
@@ -528,6 +528,8 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
# previous timestep) # previous timestep)
z_i_t::Union{Array{Bool},Nothing} = nothing z_i_t::Union{Array{Bool},Nothing} = nothing
synapticStrength::Union{Array{Float64},Nothing} = nothing
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
@@ -536,7 +538,6 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
refractoryCounter::Integer = 0 refractoryCounter::Integer = 0
tau_out::Union{Float64,Nothing} = nothing # τ_out, membrane time constant in millisecond tau_out::Union{Float64,Nothing} = nothing # τ_out, membrane time constant in millisecond
@@ -629,11 +630,11 @@ function init_neuron!(id::Int64, n::lif_neuron, n_params::Dict, kfnParams::Dict)
# prevent subscription to itself by removing this neuron id # prevent subscription to itself by removing this neuron id
filter!(x -> x != n.id, n.subscriptionList) filter!(x -> x != n.id, n.subscriptionList)
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
n.epsilonRec = zeros(length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList))
n.w_rec = Random.rand(length(n.subscriptionList)) n.w_rec = Random.rand(length(n.subscriptionList))
n.wRecChange = zeros(length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList))
# n.reg_voltage_b = zeros(length(n.subscriptionList))
n.alpha = calculate_α(n) n.alpha = calculate_α(n)
end end
@@ -648,6 +649,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
# prevent subscription to itself by removing this neuron id # prevent subscription to itself by removing this neuron id
filter!(x -> x != n.id, n.subscriptionList) filter!(x -> x != n.id, n.subscriptionList)
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
n.epsilonRec = zeros(length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList))
n.w_rec = Random.rand(length(n.subscriptionList)) n.w_rec = Random.rand(length(n.subscriptionList))
@@ -660,7 +662,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
n.epsilonRecA = zeros(length(n.subscriptionList)) n.epsilonRecA = zeros(length(n.subscriptionList))
end end
#WORKING
function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Dict) function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Dict)
n.id = id n.id = id
n.knowledgeFnName = kfnParams[:knowledgeFnName] n.knowledgeFnName = kfnParams[:knowledgeFnName]
@@ -669,7 +671,7 @@ function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Di
subscription_numbers = Int(floor(n_params[:synaptic_connection_number] * subscription_numbers = Int(floor(n_params[:synaptic_connection_number] *
kfnParams[:total_compute_neuron] / 100.0)) kfnParams[:total_compute_neuron] / 100.0))
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers] n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
n.epsilonRec = zeros(length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList))
n.w_rec = Random.rand(length(n.subscriptionList)) n.w_rec = Random.rand(length(n.subscriptionList))
n.wRecChange = zeros(length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList))