Time-based learning method based on a new error formula

This commit is contained in:
2023-05-16 20:56:05 +07:00
parent 114161ba69
commit 70d2521c5e
5 changed files with 146 additions and 227 deletions

View File

@@ -4,7 +4,7 @@ using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types
using ..types, ..snn_utils
export learn!
@@ -12,6 +12,23 @@ export learn!
function learn!(m::model, modelRespond, correctAnswer=nothing)
m.knowledgeFn[:I].learningStage = m.learningStage
# # how many matched respond and correct answer
# matched = sum(isequal(modelRespond, correctAnswer))
if correctAnswer === nothing
correctAnswer_I = zeros(length(modelRespond))
else
correctAnswer_I = correctAnswer # correct answer for kfn I
end
learn!(m.knowledgeFn[:I], correctAnswer_I)
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
# ΔWeight Conn. Strength
# case 1 no no during input signal, no correct answer available, no answer
# case 2 no - during input signal, no correct answer available, wrong answer
@@ -26,39 +43,58 @@ function learn!(m::model, modelRespond, correctAnswer=nothing)
# case 8 no - after input signal, after correct timing (late), wrong answer
# success
# how many matched respond and correct answer
matched = sum(isequal(modelRespond, correctAnswer))
correctAnswer_I = correctAnswer # correct answer for kfn I
learn!(m.knowledgeFn[:I], correctAnswer_I)
# return model_error
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer=nothing)
if kfn.learningStage == "start_learning"
# reset params here instead of at the end_learning so that neuron's parameter data
# don't gets wiped and can be logged for visualization later
for n in kfn.neuronsArray
# epsilonRec need to be reset because it counting how many each synaptic fires and
# use this info to calculate how much synaptic weight should be adjust
reset_learning_params!(n)
resetLearningParams!(n)
end
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.outputs = nothing
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "learning"
#TODO prepare for end learning
elseif kfn.learningStage == "end_learning"
reset_learning_params!(n)
resetLearningParams!(n)
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "inference"
end
# compute kfn error
out = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, v) in enumerate(out)
if v != correctAnswer[i] # need to adjust weight
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
100 / kfn.outputNeuronsArray[i].v_th
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
learn!(n, kfnError)
end
learn!(kfn.outputNeuronsArray[i], kfn)
end
end
#WORKING
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
learn!(n, kfn) # Neurons are always learning, besides error from model output
@@ -71,7 +107,7 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
# other output neurons
learn!(n, kfn)
end
#TODO: put other KFN to learn here
# for main loop user's display and training's exit condition
avgNeuronsFiringRate = 0.0
@@ -90,6 +126,25 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
end
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
end
# wrap up learning session
if kfn.learningStage == "end_learning"
#TODO neuroplasticity
resetLearningParams!(n)
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "inference"
end
end
""" passthrough_neuron learn()
@@ -100,71 +155,23 @@ end
""" lif learn()
"""
function learn!(n::lif_neuron, kfn::knowledgeFn)
if n.learnable_flag == true
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
end
function learn!(n::lif_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
# a piece of knowledgeFn error that belongs to this neuron
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
n.learningStage = kfn.learningStage
# accumulate voltage regularization terms
Snn_utils.cal_v_reg!(n)
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning" ||
n.learningStage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
end
# check for fliped sign, 1 indicates non-fliped sign
wSign = sign.(n.wRecChange)
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end
""" alif_neuron learn()
"""
function learn!(n::alif_neuron, kfn::knowledgeFn)
function learn!(n::alif_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.epsilonRecA = (n.phi * n.epsilonRec) +
@@ -173,117 +180,29 @@ function learn!(n::alif_neuron, kfn::knowledgeFn)
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
n.eRec = n.eRec_v + n.eRec_a
# a piece of knowledgeFn error that belongs to this neuron
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
n.learningStage = kfn.learningStage
ΔwRecChange = n.eta * error
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning" ||
n.learningStage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.wRecChange = n.wRecChange +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
end
# check for fliped sign, 1 indicates non-fliped sign
wSign = sign.(n.wRecChange)
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end
""" linear_neuron learn()
"""
function learn!(n::linear_neuron, kfn::knowledgeFn)
n.error = kfn.outputError[n.id]
n.learningStage = kfn.learningStage
function learn!(n::linear_neuron, error::Number)
n.decayedEpsilonRec = n.alpha * n.epsilonRec
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
if n.learningStage == "doing_inference"
# no learning
elseif n.learningStage == "start_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learningStage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learningStage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.outputNeuronsArray[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
n.w_out = n.w_out + n.w_out_change
n.b = n.b + n.b_change
end
# check for fliped sign, 1 indicates non-fliped sign
wSign = sign.(n.wRecChange)
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
end