time-based learning method based on new error formula
This commit is contained in:
@@ -34,14 +34,16 @@ using .interface
|
||||
|
||||
"""
|
||||
Todo:
|
||||
[7] time-based learning method based on new error formula
|
||||
[*6] time-based learning method based on new error formula
|
||||
(use output vt compared to vth instead of late time)
|
||||
if output neuron not activate when it should, use output neuron's
|
||||
(vth - vt)*100/vth as error
|
||||
if output neuron activates when it should NOT, use output neuron's
|
||||
(vt*100)/vth as error
|
||||
[8] verify that model can complete learning cycle with no error
|
||||
[5] synaptic connection strength concept. use sigmoid
|
||||
[6] neuroplasticity() i.e. change connection
|
||||
[7] use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge
|
||||
[9] verify that model can complete learning cycle with no error
|
||||
[*5] synaptic connection strength concept. use sigmoid, turn connection offline
|
||||
[8] neuroplasticity() i.e. change connection
|
||||
[] using RL to control learning signal
|
||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||
[] training should include adjusting α, neuron membrane potential decay factor
|
||||
|
||||
@@ -11,7 +11,6 @@ using ..types, ..snn_utils
|
||||
""" Model forward()
|
||||
"""
|
||||
function (m::model)(input_data::AbstractVector)
|
||||
# m.global_tick += 1
|
||||
m.timeStep += 1
|
||||
|
||||
# process all corresponding KFN
|
||||
@@ -31,7 +30,6 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
kfn.timeStep = m.timeStep
|
||||
kfn.softreset = m.softreset
|
||||
kfn.learningStage = m.learningStage
|
||||
kfn.error = m.error
|
||||
|
||||
# generate noise
|
||||
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
|
||||
@@ -101,8 +99,8 @@ function (n::lif_neuron)(kfn::knowledgeFn)
|
||||
# last only 1 timestep follow by a period of refractory.
|
||||
n.recSignal = n.recSignal * 0.0
|
||||
|
||||
# Exponantial decay of v_t1
|
||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
||||
# decay of v_t1
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
else
|
||||
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
|
||||
@@ -142,8 +140,8 @@ function (n::alif_neuron)(kfn::knowledgeFn)
|
||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
||||
n.recSignal = n.recSignal * 0.0
|
||||
|
||||
# Exponantial decay of v_t1
|
||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
||||
# decay of v_t1
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
n.phi = 0
|
||||
else
|
||||
n.z_t = isnothing(n.z_t) ? false : n.z_t
|
||||
@@ -187,8 +185,8 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
|
||||
# last only 1 timestep follow by a period of refractory.
|
||||
n.recSignal = n.recSignal * 0.0
|
||||
|
||||
# Exponantial decay of v_t1
|
||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
||||
# decay of v_t1
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
else
|
||||
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
|
||||
|
||||
281
src/learn.jl
281
src/learn.jl
@@ -4,7 +4,7 @@ using Flux.Optimise: apply!
|
||||
|
||||
using Statistics, Flux, Random, LinearAlgebra
|
||||
using GeneralUtils
|
||||
using ..types
|
||||
using ..types, ..snn_utils
|
||||
|
||||
export learn!
|
||||
|
||||
@@ -12,6 +12,23 @@ export learn!
|
||||
|
||||
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||
m.knowledgeFn[:I].learningStage = m.learningStage
|
||||
|
||||
# # how many matched respond and correct answer
|
||||
# matched = sum(isequal(modelRespond, correctAnswer))
|
||||
|
||||
if correctAnswer === nothing
|
||||
correctAnswer_I = zeros(length(modelRespond))
|
||||
else
|
||||
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||
end
|
||||
|
||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||
end
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
||||
|
||||
# ΔWeight Conn. Strength
|
||||
# case 1 no no during input signal, no correct answer available, no answer
|
||||
# case 2 no - during input signal, no correct answer available, wrong answer
|
||||
@@ -27,38 +44,57 @@ function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||
|
||||
# success
|
||||
|
||||
# how many matched respond and correct answer
|
||||
matched = sum(isequal(modelRespond, correctAnswer))
|
||||
|
||||
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||
|
||||
# return model_error
|
||||
end
|
||||
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
if kfn.learningStage == "start_learning"
|
||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||
# don't gets wiped and can be logged for visualization later
|
||||
for n in kfn.neuronsArray
|
||||
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
||||
# use this info to calculate how much synaptic weight should be adjust
|
||||
reset_learning_params!(n)
|
||||
resetLearningParams!(n)
|
||||
end
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.outputs = nothing
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "learning"
|
||||
|
||||
#TODO prepare for end learning
|
||||
elseif kfn.learningStage == "end_learning"
|
||||
reset_learning_params!(n)
|
||||
resetLearningParams!(n)
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "inference"
|
||||
end
|
||||
|
||||
# compute kfn error
|
||||
out = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, v) in enumerate(out)
|
||||
if v != correctAnswer[i] # need to adjust weight
|
||||
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
|
||||
100 / kfn.outputNeuronsArray[i].v_th
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
for n in kfn.neuronsArray
|
||||
learn!(n, kfnError)
|
||||
end
|
||||
|
||||
learn!(kfn.outputNeuronsArray[i], kfn)
|
||||
end
|
||||
end
|
||||
#WORKING
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
for n in kfn.neuronsArray
|
||||
learn!(n, kfn) # Neurons are always learning, besides error from model output
|
||||
@@ -71,7 +107,7 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
# other output neurons
|
||||
learn!(n, kfn)
|
||||
end
|
||||
#TODO: put other KFN to learn here
|
||||
|
||||
|
||||
# for main loop user's display and training's exit condition
|
||||
avgNeuronsFiringRate = 0.0
|
||||
@@ -90,6 +126,25 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
end
|
||||
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# wrap up learning session
|
||||
if kfn.learningStage == "end_learning"
|
||||
#TODO neuroplasticity
|
||||
resetLearningParams!(n)
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "inference"
|
||||
end
|
||||
end
|
||||
|
||||
""" passthrough_neuron learn()
|
||||
@@ -100,71 +155,23 @@ end
|
||||
|
||||
""" lif learn()
|
||||
"""
|
||||
function learn!(n::lif_neuron, kfn::knowledgeFn)
|
||||
if n.learnable_flag == true
|
||||
function learn!(n::lif_neuron, error::Number)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
end
|
||||
ΔwRecChange = n.eta * error
|
||||
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||
|
||||
# a piece of knowledgeFn error that belongs to this neuron
|
||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
||||
n.learningStage = kfn.learningStage
|
||||
|
||||
# accumulate voltage regularization terms
|
||||
Snn_utils.cal_v_reg!(n)
|
||||
|
||||
if n.learningStage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learningStage == "start_learning" ||
|
||||
n.learningStage == "start_learning_no_wchange_reset"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "end_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
|
||||
not_zero = (!iszero).(n.w_rec)
|
||||
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
|
||||
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
|
||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
||||
|
||||
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
|
||||
end
|
||||
# check for fliped sign, 1 indicates non-fliped sign
|
||||
wSign = sign.(n.wRecChange)
|
||||
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
end
|
||||
|
||||
""" alif_neuron learn()
|
||||
"""
|
||||
function learn!(n::alif_neuron, kfn::knowledgeFn)
|
||||
function learn!(n::alif_neuron, error::Number)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.epsilonRecA = (n.phi * n.epsilonRec) +
|
||||
@@ -173,117 +180,29 @@ function learn!(n::alif_neuron, kfn::knowledgeFn)
|
||||
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
|
||||
n.eRec = n.eRec_v + n.eRec_a
|
||||
|
||||
# a piece of knowledgeFn error that belongs to this neuron
|
||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
||||
n.learningStage = kfn.learningStage
|
||||
ΔwRecChange = n.eta * error
|
||||
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||
|
||||
|
||||
|
||||
if n.learningStage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learningStage == "start_learning" ||
|
||||
n.learningStage == "start_learning_no_wchange_reset"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "end_learning"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
|
||||
not_zero = (!iszero).(n.w_rec)
|
||||
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
|
||||
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
|
||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
||||
|
||||
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
|
||||
end
|
||||
# check for fliped sign, 1 indicates non-fliped sign
|
||||
wSign = sign.(n.wRecChange)
|
||||
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
end
|
||||
|
||||
""" linear_neuron learn()
|
||||
"""
|
||||
function learn!(n::linear_neuron, kfn::knowledgeFn)
|
||||
n.error = kfn.outputError[n.id]
|
||||
n.learningStage = kfn.learningStage
|
||||
function learn!(n::linear_neuron, error::Number)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
|
||||
if n.learningStage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learningStage == "start_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.outputNeuronsArray[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
elseif n.learningStage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.outputNeuronsArray[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
elseif n.learningStage == "end_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.outputNeuronsArray[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
ΔwRecChange = n.eta * error
|
||||
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||
|
||||
n.w_out = n.w_out + n.w_out_change
|
||||
n.b = n.b + n.b_change
|
||||
end
|
||||
# check for fliped sign, 1 indicates non-fliped sign
|
||||
wSign = sign.(n.wRecChange)
|
||||
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
end
|
||||
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ module snn_utils
|
||||
using Flux.Optimise: apply!
|
||||
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
|
||||
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
||||
reset_z_t!, reset_learning_params!, reset_learning_history_params!,
|
||||
reset_z_t!, resetLearningParams!, reset_learning_history_params!,
|
||||
cal_v_reg!, calculate_w_change_end!,
|
||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||
gradient_withloss
|
||||
|
||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote
|
||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
||||
|
||||
using ..types
|
||||
|
||||
@@ -98,21 +98,19 @@ reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0
|
||||
|
||||
""" Reset all learning-related params at the END of learning session
|
||||
"""
|
||||
function reset_learning_params!(n::lif_neuron)
|
||||
function resetLearningParams!(n::lif_neuron)
|
||||
reset_epsilon_rec!(n)
|
||||
reset_w_rec_change!(n)
|
||||
# reset_v_t!(n)
|
||||
# reset_z_t!(n)
|
||||
reset_firing_counter!(n)
|
||||
reset_firing_diff!(n)
|
||||
reset_previous_error!(n)
|
||||
reset_error!(n)
|
||||
|
||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||
# refractory state, it will stay in refractory state forever
|
||||
reset_refractoryCounter!(n)
|
||||
end
|
||||
function reset_learning_params!(n::alif_neuron)
|
||||
function resetLearningParams!(n::alif_neuron)
|
||||
reset_epsilon_rec!(n)
|
||||
reset_epsilon_rec_a!(n)
|
||||
reset_w_rec_change!(n)
|
||||
@@ -121,8 +119,6 @@ function reset_learning_params!(n::alif_neuron)
|
||||
# reset_a!(n)
|
||||
reset_firing_counter!(n)
|
||||
reset_firing_diff!(n)
|
||||
reset_previous_error!(n)
|
||||
reset_error!(n)
|
||||
|
||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||
# refractory state, it will stay in refractory state forever
|
||||
@@ -132,18 +128,15 @@ end
|
||||
# function reset_learning_no_wchange!(n::passthrough_neuron)
|
||||
# end
|
||||
|
||||
function reset_learning_params!(n::passthrough_neuron)
|
||||
function resetLearningParams!(n::passthrough_neuron)
|
||||
# skip
|
||||
end
|
||||
#WORKING
|
||||
function reset_learning_params!(n::linear_neuron)
|
||||
|
||||
function resetLearningParams!(n::linear_neuron)
|
||||
reset_epsilon_rec!(n)
|
||||
reset_w_rec_change!(n)
|
||||
reset_v_t!(n)
|
||||
reset_firing_counter!(n)
|
||||
reset_firing_diff!(n)
|
||||
reset_previous_error!(n)
|
||||
reset_error!(n)
|
||||
|
||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||
# refractory state, it will stay in refractory state forever
|
||||
@@ -288,14 +281,19 @@ function push_epsilon_rec_a!(n::alif_neuron)
|
||||
push!(n.epsilonRecA, 0)
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
||||
sigmoid operating range which centred at 0 and range is -37 to 37.
|
||||
# Example
|
||||
synaptic strength range is 0 to 10
|
||||
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
||||
the return value is shifted back to original scale
|
||||
"""
|
||||
function synapticConnStrength(currentStrength::AbstractFloat, bias::Number=0)
|
||||
currentStrength += bias
|
||||
currentStrength - (1.0 - sigmoid(currentStrength))
|
||||
currentStrength -= bias
|
||||
return currentStrength
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
18
src/types.jl
18
src/types.jl
@@ -325,7 +325,8 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
||||
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
||||
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
||||
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
|
||||
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out
|
||||
synapticStrength::Union{Array{Float64},Nothing} = nothing
|
||||
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(0=>0), upperlimit=(10=>10))
|
||||
|
||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||
@@ -334,7 +335,6 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
||||
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
||||
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
||||
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
||||
refractoryCounter::Integer = 0
|
||||
@@ -418,7 +418,8 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
||||
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
||||
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
||||
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
|
||||
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out
|
||||
synapticStrength::Union{Array{Float64},Nothing} = nothing
|
||||
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
||||
|
||||
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||
@@ -430,7 +431,6 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
||||
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
|
||||
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
|
||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
||||
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
|
||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
|
||||
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
||||
@@ -528,6 +528,8 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
|
||||
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
|
||||
# previous timestep)
|
||||
z_i_t::Union{Array{Bool},Nothing} = nothing
|
||||
synapticStrength::Union{Array{Float64},Nothing} = nothing
|
||||
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
||||
|
||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||
@@ -536,7 +538,6 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
|
||||
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
||||
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
||||
refractoryCounter::Integer = 0
|
||||
tau_out::Union{Float64,Nothing} = nothing # τ_out, membrane time constant in millisecond
|
||||
@@ -629,11 +630,11 @@ function init_neuron!(id::Int64, n::lif_neuron, n_params::Dict, kfnParams::Dict)
|
||||
|
||||
# prevent subscription to itself by removing this neuron id
|
||||
filter!(x -> x != n.id, n.subscriptionList)
|
||||
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
n.w_rec = Random.rand(length(n.subscriptionList))
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
# n.reg_voltage_b = zeros(length(n.subscriptionList))
|
||||
n.alpha = calculate_α(n)
|
||||
end
|
||||
|
||||
@@ -648,6 +649,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
|
||||
|
||||
# prevent subscription to itself by removing this neuron id
|
||||
filter!(x -> x != n.id, n.subscriptionList)
|
||||
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
n.w_rec = Random.rand(length(n.subscriptionList))
|
||||
@@ -660,7 +662,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
|
||||
n.epsilonRecA = zeros(length(n.subscriptionList))
|
||||
end
|
||||
|
||||
#WORKING
|
||||
|
||||
function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Dict)
|
||||
n.id = id
|
||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
@@ -669,7 +671,7 @@ function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Di
|
||||
subscription_numbers = Int(floor(n_params[:synaptic_connection_number] *
|
||||
kfnParams[:total_compute_neuron] / 100.0))
|
||||
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||
|
||||
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
n.w_rec = Random.rand(length(n.subscriptionList))
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
|
||||
Reference in New Issue
Block a user