time-based learning method based on new error formula
This commit is contained in:
@@ -34,14 +34,16 @@ using .interface
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Todo:
|
Todo:
|
||||||
[7] time-based learning method based on new error formula
|
[*6] time-based learning method based on new error formula
|
||||||
|
(use output vt compared to vth instead of late time)
|
||||||
if output neuron not activate when it should, use output neuron's
|
if output neuron not activate when it should, use output neuron's
|
||||||
(vth - vt)*100/vth as error
|
(vth - vt)*100/vth as error
|
||||||
if output neuron activates when it should NOT, use output neuron's
|
if output neuron activates when it should NOT, use output neuron's
|
||||||
(vt*100)/vth as error
|
(vt*100)/vth as error
|
||||||
[8] verify that model can complete learning cycle with no error
|
[7] use LinearAlgebra.normalize!(vector, 1) to adjust weight after weight merge
|
||||||
[5] synaptic connection strength concept. use sigmoid
|
[9] verify that model can complete learning cycle with no error
|
||||||
[6] neuroplasticity() i.e. change connection
|
[*5] synaptic connection strength concept. use sigmoid, turn connection offline
|
||||||
|
[8] neuroplasticity() i.e. change connection
|
||||||
[] using RL to control learning signal
|
[] using RL to control learning signal
|
||||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||||
[] training should include adjusting α, neuron membrane potential decay factor
|
[] training should include adjusting α, neuron membrane potential decay factor
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ using ..types, ..snn_utils
|
|||||||
""" Model forward()
|
""" Model forward()
|
||||||
"""
|
"""
|
||||||
function (m::model)(input_data::AbstractVector)
|
function (m::model)(input_data::AbstractVector)
|
||||||
# m.global_tick += 1
|
|
||||||
m.timeStep += 1
|
m.timeStep += 1
|
||||||
|
|
||||||
# process all corresponding KFN
|
# process all corresponding KFN
|
||||||
@@ -31,7 +30,6 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
kfn.timeStep = m.timeStep
|
kfn.timeStep = m.timeStep
|
||||||
kfn.softreset = m.softreset
|
kfn.softreset = m.softreset
|
||||||
kfn.learningStage = m.learningStage
|
kfn.learningStage = m.learningStage
|
||||||
kfn.error = m.error
|
|
||||||
|
|
||||||
# generate noise
|
# generate noise
|
||||||
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
|
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
|
||||||
@@ -101,8 +99,8 @@ function (n::lif_neuron)(kfn::knowledgeFn)
|
|||||||
# last only 1 timestep follow by a period of refractory.
|
# last only 1 timestep follow by a period of refractory.
|
||||||
n.recSignal = n.recSignal * 0.0
|
n.recSignal = n.recSignal * 0.0
|
||||||
|
|
||||||
# Exponantial decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
else
|
else
|
||||||
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
|
|
||||||
@@ -142,8 +140,8 @@ function (n::alif_neuron)(kfn::knowledgeFn)
|
|||||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
||||||
n.recSignal = n.recSignal * 0.0
|
n.recSignal = n.recSignal * 0.0
|
||||||
|
|
||||||
# Exponantial decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
n.phi = 0
|
n.phi = 0
|
||||||
else
|
else
|
||||||
n.z_t = isnothing(n.z_t) ? false : n.z_t
|
n.z_t = isnothing(n.z_t) ? false : n.z_t
|
||||||
@@ -187,8 +185,8 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
|
|||||||
# last only 1 timestep follow by a period of refractory.
|
# last only 1 timestep follow by a period of refractory.
|
||||||
n.recSignal = n.recSignal * 0.0
|
n.recSignal = n.recSignal * 0.0
|
||||||
|
|
||||||
# Exponantial decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
else
|
else
|
||||||
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
|
|
||||||
|
|||||||
275
src/learn.jl
275
src/learn.jl
@@ -4,7 +4,7 @@ using Flux.Optimise: apply!
|
|||||||
|
|
||||||
using Statistics, Flux, Random, LinearAlgebra
|
using Statistics, Flux, Random, LinearAlgebra
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
export learn!
|
export learn!
|
||||||
|
|
||||||
@@ -12,6 +12,23 @@ export learn!
|
|||||||
|
|
||||||
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||||
m.knowledgeFn[:I].learningStage = m.learningStage
|
m.knowledgeFn[:I].learningStage = m.learningStage
|
||||||
|
|
||||||
|
# # how many matched respond and correct answer
|
||||||
|
# matched = sum(isequal(modelRespond, correctAnswer))
|
||||||
|
|
||||||
|
if correctAnswer === nothing
|
||||||
|
correctAnswer_I = zeros(length(modelRespond))
|
||||||
|
else
|
||||||
|
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||||
|
end
|
||||||
|
|
||||||
|
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||||
|
end
|
||||||
|
|
||||||
|
""" knowledgeFn learn()
|
||||||
|
"""
|
||||||
|
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
||||||
|
|
||||||
# ΔWeight Conn. Strength
|
# ΔWeight Conn. Strength
|
||||||
# case 1 no no during input signal, no correct answer available, no answer
|
# case 1 no no during input signal, no correct answer available, no answer
|
||||||
# case 2 no - during input signal, no correct answer available, wrong answer
|
# case 2 no - during input signal, no correct answer available, wrong answer
|
||||||
@@ -27,38 +44,57 @@ function learn!(m::model, modelRespond, correctAnswer=nothing)
|
|||||||
|
|
||||||
# success
|
# success
|
||||||
|
|
||||||
# how many matched respond and correct answer
|
|
||||||
matched = sum(isequal(modelRespond, correctAnswer))
|
|
||||||
|
|
||||||
correctAnswer_I = correctAnswer # correct answer for kfn I
|
|
||||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
|
||||||
|
|
||||||
# return model_error
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
""" knowledgeFn learn()
|
|
||||||
"""
|
|
||||||
function learn!(kfn::kfn_1, correctAnswer=nothing)
|
|
||||||
if kfn.learningStage == "start_learning"
|
if kfn.learningStage == "start_learning"
|
||||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||||
# don't gets wiped and can be logged for visualization later
|
# don't gets wiped and can be logged for visualization later
|
||||||
for n in kfn.neuronsArray
|
for n in kfn.neuronsArray
|
||||||
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
||||||
# use this info to calculate how much synaptic weight should be adjust
|
# use this info to calculate how much synaptic weight should be adjust
|
||||||
reset_learning_params!(n)
|
resetLearningParams!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
# clear variables
|
# clear variables
|
||||||
kfn.firedNeurons = Vector{Int64}()
|
kfn.firedNeurons = Vector{Int64}()
|
||||||
kfn.outputs = nothing
|
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||||
|
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||||
|
|
||||||
kfn.learningStage = "learning"
|
kfn.learningStage = "learning"
|
||||||
|
|
||||||
|
#TODO prepare for end learning
|
||||||
elseif kfn.learningStage == "end_learning"
|
elseif kfn.learningStage == "end_learning"
|
||||||
reset_learning_params!(n)
|
resetLearningParams!(n)
|
||||||
|
|
||||||
|
# clear variables
|
||||||
|
kfn.firedNeurons = Vector{Int64}()
|
||||||
|
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||||
|
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||||
|
|
||||||
kfn.learningStage = "inference"
|
kfn.learningStage = "inference"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# compute kfn error
|
||||||
|
out = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||||
|
for (i, v) in enumerate(out)
|
||||||
|
if v != correctAnswer[i] # need to adjust weight
|
||||||
|
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
|
||||||
|
100 / kfn.outputNeuronsArray[i].v_th
|
||||||
|
|
||||||
|
# Threads.@threads for n in kfn.neuronsArray
|
||||||
|
for n in kfn.neuronsArray
|
||||||
|
learn!(n, kfnError)
|
||||||
|
end
|
||||||
|
|
||||||
|
learn!(kfn.outputNeuronsArray[i], kfn)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
#WORKING
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
# Threads.@threads for n in kfn.neuronsArray
|
||||||
for n in kfn.neuronsArray
|
for n in kfn.neuronsArray
|
||||||
learn!(n, kfn) # Neurons are always learning, besides error from model output
|
learn!(n, kfn) # Neurons are always learning, besides error from model output
|
||||||
@@ -71,7 +107,7 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
|
|||||||
# other output neurons
|
# other output neurons
|
||||||
learn!(n, kfn)
|
learn!(n, kfn)
|
||||||
end
|
end
|
||||||
#TODO: put other KFN to learn here
|
|
||||||
|
|
||||||
# for main loop user's display and training's exit condition
|
# for main loop user's display and training's exit condition
|
||||||
avgNeuronsFiringRate = 0.0
|
avgNeuronsFiringRate = 0.0
|
||||||
@@ -90,6 +126,25 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
|
|||||||
end
|
end
|
||||||
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
|
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# wrap up learning session
|
||||||
|
if kfn.learningStage == "end_learning"
|
||||||
|
#TODO neuroplasticity
|
||||||
|
resetLearningParams!(n)
|
||||||
|
|
||||||
|
# clear variables
|
||||||
|
kfn.firedNeurons = Vector{Int64}()
|
||||||
|
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||||
|
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||||
|
|
||||||
|
kfn.learningStage = "inference"
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
""" passthrough_neuron learn()
|
""" passthrough_neuron learn()
|
||||||
@@ -100,71 +155,23 @@ end
|
|||||||
|
|
||||||
""" lif learn()
|
""" lif learn()
|
||||||
"""
|
"""
|
||||||
function learn!(n::lif_neuron, kfn::knowledgeFn)
|
function learn!(n::lif_neuron, error::Number)
|
||||||
if n.learnable_flag == true
|
|
||||||
|
|
||||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||||
n.eRec = n.phi * n.epsilonRec
|
n.eRec = n.phi * n.epsilonRec
|
||||||
end
|
|
||||||
|
|
||||||
# a piece of knowledgeFn error that belongs to this neuron
|
ΔwRecChange = n.eta * error
|
||||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||||
n.learningStage = kfn.learningStage
|
|
||||||
|
|
||||||
# accumulate voltage regularization terms
|
# check for fliped sign, 1 indicates non-fliped sign
|
||||||
Snn_utils.cal_v_reg!(n)
|
wSign = sign.(n.wRecChange)
|
||||||
|
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||||
if n.learningStage == "doing_inference"
|
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||||
# no learning
|
|
||||||
elseif n.learningStage == "start_learning" ||
|
|
||||||
n.learningStage == "start_learning_no_wchange_reset"
|
|
||||||
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing
|
|
||||||
Snn_utils.firing_rate!(n)
|
|
||||||
Snn_utils.firing_diff!(n)
|
|
||||||
n.wRecChange = n.wRecChange +
|
|
||||||
-apply!(n.optimiser, n.w_rec,
|
|
||||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
|
||||||
-Snn_utils.firing_rate_regulator!(n) +
|
|
||||||
-Snn_utils.voltage_regulator!(n)
|
|
||||||
end
|
|
||||||
elseif n.learningStage == "during_learning"
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing
|
|
||||||
Snn_utils.firing_rate!(n)
|
|
||||||
Snn_utils.firing_diff!(n)
|
|
||||||
n.wRecChange = n.wRecChange +
|
|
||||||
-apply!(n.optimiser, n.w_rec,
|
|
||||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
|
||||||
-Snn_utils.firing_rate_regulator!(n) +
|
|
||||||
-Snn_utils.voltage_regulator!(n)
|
|
||||||
end
|
|
||||||
elseif n.learningStage == "end_learning"
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing
|
|
||||||
Snn_utils.firing_rate!(n)
|
|
||||||
Snn_utils.firing_diff!(n)
|
|
||||||
n.wRecChange = n.wRecChange +
|
|
||||||
-apply!(n.optimiser, n.w_rec,
|
|
||||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
|
||||||
-Snn_utils.firing_rate_regulator!(n) +
|
|
||||||
-Snn_utils.voltage_regulator!(n)
|
|
||||||
end
|
|
||||||
|
|
||||||
not_zero = (!iszero).(n.w_rec)
|
|
||||||
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
|
|
||||||
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
|
|
||||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
|
||||||
|
|
||||||
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
""" alif_neuron learn()
|
""" alif_neuron learn()
|
||||||
"""
|
"""
|
||||||
function learn!(n::alif_neuron, kfn::knowledgeFn)
|
function learn!(n::alif_neuron, error::Number)
|
||||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||||
n.epsilonRecA = (n.phi * n.epsilonRec) +
|
n.epsilonRecA = (n.phi * n.epsilonRec) +
|
||||||
@@ -173,117 +180,29 @@ function learn!(n::alif_neuron, kfn::knowledgeFn)
|
|||||||
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
|
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
|
||||||
n.eRec = n.eRec_v + n.eRec_a
|
n.eRec = n.eRec_v + n.eRec_a
|
||||||
|
|
||||||
# a piece of knowledgeFn error that belongs to this neuron
|
ΔwRecChange = n.eta * error
|
||||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||||
n.learningStage = kfn.learningStage
|
|
||||||
|
|
||||||
|
# check for fliped sign, 1 indicates non-fliped sign
|
||||||
|
wSign = sign.(n.wRecChange)
|
||||||
if n.learningStage == "doing_inference"
|
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||||
# no learning
|
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||||
elseif n.learningStage == "start_learning" ||
|
|
||||||
n.learningStage == "start_learning_no_wchange_reset"
|
|
||||||
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing
|
|
||||||
Snn_utils.firing_rate!(n)
|
|
||||||
Snn_utils.firing_diff!(n)
|
|
||||||
n.wRecChange = n.wRecChange +
|
|
||||||
-apply!(n.optimiser, n.w_rec,
|
|
||||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
|
||||||
-Snn_utils.firing_rate_regulator!(n) +
|
|
||||||
-Snn_utils.voltage_regulator!(n)
|
|
||||||
end
|
|
||||||
elseif n.learningStage == "during_learning"
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing
|
|
||||||
Snn_utils.firing_rate!(n)
|
|
||||||
Snn_utils.firing_diff!(n)
|
|
||||||
n.wRecChange = n.wRecChange +
|
|
||||||
-apply!(n.optimiser, n.w_rec,
|
|
||||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
|
||||||
-Snn_utils.firing_rate_regulator!(n) +
|
|
||||||
-Snn_utils.voltage_regulator!(n)
|
|
||||||
end
|
|
||||||
elseif n.learningStage == "end_learning"
|
|
||||||
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing
|
|
||||||
Snn_utils.firing_rate!(n)
|
|
||||||
Snn_utils.firing_diff!(n)
|
|
||||||
n.wRecChange = n.wRecChange +
|
|
||||||
-apply!(n.optimiser, n.w_rec,
|
|
||||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
|
||||||
-Snn_utils.firing_rate_regulator!(n) +
|
|
||||||
-Snn_utils.voltage_regulator!(n)
|
|
||||||
end
|
|
||||||
|
|
||||||
not_zero = (!iszero).(n.w_rec)
|
|
||||||
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
|
|
||||||
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
|
|
||||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
|
||||||
|
|
||||||
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
""" linear_neuron learn()
|
""" linear_neuron learn()
|
||||||
"""
|
"""
|
||||||
function learn!(n::linear_neuron, kfn::knowledgeFn)
|
function learn!(n::linear_neuron, error::Number)
|
||||||
n.error = kfn.outputError[n.id]
|
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||||
n.learningStage = kfn.learningStage
|
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||||
|
n.eRec = n.phi * n.epsilonRec
|
||||||
|
|
||||||
if n.learningStage == "doing_inference"
|
ΔwRecChange = n.eta * error
|
||||||
# no learning
|
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||||
elseif n.learningStage == "start_learning"
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
|
||||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
|
||||||
n.w_out_change = n.w_out_change + Δw
|
|
||||||
n.eta = n.optimiser.eta
|
|
||||||
Δb = -n.eta * n.error
|
|
||||||
n.b_change = n.b_change + Δb
|
|
||||||
elseif n.error !== nothing && n.id !== 1
|
|
||||||
n.eta = kfn.outputNeuronsArray[1].eta
|
|
||||||
Δw = -n.eta * n.error * n.epsilon_j
|
|
||||||
n.w_out_change = n.w_out_change + Δw
|
|
||||||
Δb = -n.eta * n.error
|
|
||||||
n.b_change = n.b_change + Δb
|
|
||||||
end
|
|
||||||
elseif n.learningStage == "during_learning"
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
|
||||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
|
||||||
n.w_out_change = n.w_out_change + Δw
|
|
||||||
n.eta = n.optimiser.eta
|
|
||||||
Δb = -n.eta * n.error
|
|
||||||
n.b_change = n.b_change + Δb
|
|
||||||
elseif n.error !== nothing && n.id !== 1
|
|
||||||
n.eta = kfn.outputNeuronsArray[1].eta
|
|
||||||
Δw = -n.eta * n.error * n.epsilon_j
|
|
||||||
n.w_out_change = n.w_out_change + Δw
|
|
||||||
Δb = -n.eta * n.error
|
|
||||||
n.b_change = n.b_change + Δb
|
|
||||||
end
|
|
||||||
elseif n.learningStage == "end_learning"
|
|
||||||
# if error signal available then accumulates Δw
|
|
||||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
|
||||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
|
||||||
n.w_out_change = n.w_out_change + Δw
|
|
||||||
n.eta = n.optimiser.eta
|
|
||||||
Δb = -n.eta * n.error
|
|
||||||
n.b_change = n.b_change + Δb
|
|
||||||
elseif n.error !== nothing && n.id !== 1
|
|
||||||
n.eta = kfn.outputNeuronsArray[1].eta
|
|
||||||
Δw = -n.eta * n.error * n.epsilon_j
|
|
||||||
n.w_out_change = n.w_out_change + Δw
|
|
||||||
Δb = -n.eta * n.error
|
|
||||||
n.b_change = n.b_change + Δb
|
|
||||||
end
|
|
||||||
|
|
||||||
n.w_out = n.w_out + n.w_out_change
|
# check for fliped sign, 1 indicates non-fliped sign
|
||||||
n.b = n.b + n.b_change
|
wSign = sign.(n.wRecChange)
|
||||||
end
|
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||||
|
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ module snn_utils
|
|||||||
using Flux.Optimise: apply!
|
using Flux.Optimise: apply!
|
||||||
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
|
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
|
||||||
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
||||||
reset_z_t!, reset_learning_params!, reset_learning_history_params!,
|
reset_z_t!, resetLearningParams!, reset_learning_history_params!,
|
||||||
cal_v_reg!, calculate_w_change_end!,
|
cal_v_reg!, calculate_w_change_end!,
|
||||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||||
gradient_withloss
|
gradient_withloss
|
||||||
|
|
||||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote
|
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
||||||
|
|
||||||
using ..types
|
using ..types
|
||||||
|
|
||||||
@@ -98,21 +98,19 @@ reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0
|
|||||||
|
|
||||||
""" Reset all learning-related params at the END of learning session
|
""" Reset all learning-related params at the END of learning session
|
||||||
"""
|
"""
|
||||||
function reset_learning_params!(n::lif_neuron)
|
function resetLearningParams!(n::lif_neuron)
|
||||||
reset_epsilon_rec!(n)
|
reset_epsilon_rec!(n)
|
||||||
reset_w_rec_change!(n)
|
reset_w_rec_change!(n)
|
||||||
# reset_v_t!(n)
|
# reset_v_t!(n)
|
||||||
# reset_z_t!(n)
|
# reset_z_t!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
reset_firing_diff!(n)
|
reset_firing_diff!(n)
|
||||||
reset_previous_error!(n)
|
|
||||||
reset_error!(n)
|
|
||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
reset_refractoryCounter!(n)
|
reset_refractoryCounter!(n)
|
||||||
end
|
end
|
||||||
function reset_learning_params!(n::alif_neuron)
|
function resetLearningParams!(n::alif_neuron)
|
||||||
reset_epsilon_rec!(n)
|
reset_epsilon_rec!(n)
|
||||||
reset_epsilon_rec_a!(n)
|
reset_epsilon_rec_a!(n)
|
||||||
reset_w_rec_change!(n)
|
reset_w_rec_change!(n)
|
||||||
@@ -121,8 +119,6 @@ function reset_learning_params!(n::alif_neuron)
|
|||||||
# reset_a!(n)
|
# reset_a!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
reset_firing_diff!(n)
|
reset_firing_diff!(n)
|
||||||
reset_previous_error!(n)
|
|
||||||
reset_error!(n)
|
|
||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
@@ -132,18 +128,15 @@ end
|
|||||||
# function reset_learning_no_wchange!(n::passthrough_neuron)
|
# function reset_learning_no_wchange!(n::passthrough_neuron)
|
||||||
# end
|
# end
|
||||||
|
|
||||||
function reset_learning_params!(n::passthrough_neuron)
|
function resetLearningParams!(n::passthrough_neuron)
|
||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
#WORKING
|
|
||||||
function reset_learning_params!(n::linear_neuron)
|
function resetLearningParams!(n::linear_neuron)
|
||||||
reset_epsilon_rec!(n)
|
reset_epsilon_rec!(n)
|
||||||
reset_w_rec_change!(n)
|
reset_w_rec_change!(n)
|
||||||
reset_v_t!(n)
|
reset_v_t!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
reset_firing_diff!(n)
|
|
||||||
reset_previous_error!(n)
|
|
||||||
reset_error!(n)
|
|
||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
@@ -288,14 +281,19 @@ function push_epsilon_rec_a!(n::alif_neuron)
|
|||||||
push!(n.epsilonRecA, 0)
|
push!(n.epsilonRecA, 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
||||||
|
sigmoid operating range which centred at 0 and range is -37 to 37.
|
||||||
|
# Example
|
||||||
|
synaptic strength range is 0 to 10
|
||||||
|
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
||||||
|
the return value is shifted back to original scale
|
||||||
|
"""
|
||||||
|
function synapticConnStrength(currentStrength::AbstractFloat, bias::Number=0)
|
||||||
|
currentStrength += bias
|
||||||
|
currentStrength - (1.0 - sigmoid(currentStrength))
|
||||||
|
currentStrength -= bias
|
||||||
|
return currentStrength
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
18
src/types.jl
18
src/types.jl
@@ -325,7 +325,8 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
|||||||
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
||||||
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
||||||
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
|
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
|
||||||
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out
|
synapticStrength::Union{Array{Float64},Nothing} = nothing
|
||||||
|
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(0=>0), upperlimit=(10=>10))
|
||||||
|
|
||||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||||
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||||
@@ -334,7 +335,6 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
|||||||
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
||||||
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
||||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
|
||||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
||||||
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
||||||
refractoryCounter::Integer = 0
|
refractoryCounter::Integer = 0
|
||||||
@@ -418,7 +418,8 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
|||||||
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
||||||
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
||||||
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
|
z_i_t::Union{Array{Bool},Nothing} = nothing # neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of previous timestep)
|
||||||
# Bn_wout_decay::Union{Float64,Nothing} = 0.01 # use to balance Bn and w_out
|
synapticStrength::Union{Array{Float64},Nothing} = nothing
|
||||||
|
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
||||||
|
|
||||||
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||||
@@ -430,7 +431,6 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
|||||||
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
|
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
|
||||||
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
|
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
|
||||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
|
||||||
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
|
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
|
||||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
|
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
|
||||||
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
||||||
@@ -528,6 +528,8 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
|
|||||||
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
|
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
|
||||||
# previous timestep)
|
# previous timestep)
|
||||||
z_i_t::Union{Array{Bool},Nothing} = nothing
|
z_i_t::Union{Array{Bool},Nothing} = nothing
|
||||||
|
synapticStrength::Union{Array{Float64},Nothing} = nothing
|
||||||
|
synapticStrengthLimit::Union{NamedTuple,Nothing} = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
||||||
|
|
||||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||||
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||||
@@ -536,7 +538,6 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
|
|||||||
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
||||||
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
||||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
|
||||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
||||||
refractoryCounter::Integer = 0
|
refractoryCounter::Integer = 0
|
||||||
tau_out::Union{Float64,Nothing} = nothing # τ_out, membrane time constant in millisecond
|
tau_out::Union{Float64,Nothing} = nothing # τ_out, membrane time constant in millisecond
|
||||||
@@ -629,11 +630,11 @@ function init_neuron!(id::Int64, n::lif_neuron, n_params::Dict, kfnParams::Dict)
|
|||||||
|
|
||||||
# prevent subscription to itself by removing this neuron id
|
# prevent subscription to itself by removing this neuron id
|
||||||
filter!(x -> x != n.id, n.subscriptionList)
|
filter!(x -> x != n.id, n.subscriptionList)
|
||||||
|
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.w_rec = Random.rand(length(n.subscriptionList))
|
n.w_rec = Random.rand(length(n.subscriptionList))
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
# n.reg_voltage_b = zeros(length(n.subscriptionList))
|
|
||||||
n.alpha = calculate_α(n)
|
n.alpha = calculate_α(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -648,6 +649,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
|
|||||||
|
|
||||||
# prevent subscription to itself by removing this neuron id
|
# prevent subscription to itself by removing this neuron id
|
||||||
filter!(x -> x != n.id, n.subscriptionList)
|
filter!(x -> x != n.id, n.subscriptionList)
|
||||||
|
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.w_rec = Random.rand(length(n.subscriptionList))
|
n.w_rec = Random.rand(length(n.subscriptionList))
|
||||||
@@ -660,7 +662,7 @@ function init_neuron!(id::Int64, n::alif_neuron, n_params::Dict,
|
|||||||
n.epsilonRecA = zeros(length(n.subscriptionList))
|
n.epsilonRecA = zeros(length(n.subscriptionList))
|
||||||
end
|
end
|
||||||
|
|
||||||
#WORKING
|
|
||||||
function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Dict)
|
function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Dict)
|
||||||
n.id = id
|
n.id = id
|
||||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||||
@@ -669,7 +671,7 @@ function init_neuron!(id::Int64, n::linear_neuron, n_params::Dict, kfnParams::Di
|
|||||||
subscription_numbers = Int(floor(n_params[:synaptic_connection_number] *
|
subscription_numbers = Int(floor(n_params[:synaptic_connection_number] *
|
||||||
kfnParams[:total_compute_neuron] / 100.0))
|
kfnParams[:total_compute_neuron] / 100.0))
|
||||||
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||||
|
n.synapticStrength = normalize!(rand(length(n.subscriptionList)), 1)
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.w_rec = Random.rand(length(n.subscriptionList))
|
n.w_rec = Random.rand(length(n.subscriptionList))
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
|
|||||||
Reference in New Issue
Block a user