time-based learning method based on new error formula
This commit is contained in:
287
src/learn.jl
287
src/learn.jl
@@ -4,7 +4,7 @@ using Flux.Optimise: apply!
|
||||
|
||||
using Statistics, Flux, Random, LinearAlgebra
|
||||
using GeneralUtils
|
||||
using ..types
|
||||
using ..types, ..snn_utils
|
||||
|
||||
export learn!
|
||||
|
||||
@@ -12,6 +12,23 @@ export learn!
|
||||
|
||||
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||
m.knowledgeFn[:I].learningStage = m.learningStage
|
||||
|
||||
# # how many matched respond and correct answer
|
||||
# matched = sum(isequal(modelRespond, correctAnswer))
|
||||
|
||||
if correctAnswer === nothing
|
||||
correctAnswer_I = zeros(length(modelRespond))
|
||||
else
|
||||
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||
end
|
||||
|
||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||
end
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
||||
|
||||
# ΔWeight Conn. Strength
|
||||
# case 1 no no during input signal, no correct answer available, no answer
|
||||
# case 2 no - during input signal, no correct answer available, wrong answer
|
||||
@@ -26,39 +43,58 @@ function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||
# case 8 no - after input signal, after correct timing (late), wrong answer
|
||||
|
||||
# success
|
||||
|
||||
# how many matched respond and correct answer
|
||||
matched = sum(isequal(modelRespond, correctAnswer))
|
||||
|
||||
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||
|
||||
# return model_error
|
||||
end
|
||||
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
|
||||
if kfn.learningStage == "start_learning"
|
||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||
# don't gets wiped and can be logged for visualization later
|
||||
for n in kfn.neuronsArray
|
||||
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
||||
# use this info to calculate how much synaptic weight should be adjust
|
||||
reset_learning_params!(n)
|
||||
resetLearningParams!(n)
|
||||
end
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.outputs = nothing
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "learning"
|
||||
|
||||
#TODO prepare for end learning
|
||||
elseif kfn.learningStage == "end_learning"
|
||||
reset_learning_params!(n)
|
||||
resetLearningParams!(n)
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "inference"
|
||||
end
|
||||
|
||||
# compute kfn error
|
||||
out = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, v) in enumerate(out)
|
||||
if v != correctAnswer[i] # need to adjust weight
|
||||
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
|
||||
100 / kfn.outputNeuronsArray[i].v_th
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
for n in kfn.neuronsArray
|
||||
learn!(n, kfnError)
|
||||
end
|
||||
|
||||
learn!(kfn.outputNeuronsArray[i], kfn)
|
||||
end
|
||||
end
|
||||
#WORKING
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
for n in kfn.neuronsArray
|
||||
learn!(n, kfn) # Neurons are always learning, besides error from model output
|
||||
@@ -71,7 +107,7 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
# other output neurons
|
||||
learn!(n, kfn)
|
||||
end
|
||||
#TODO: put other KFN to learn here
|
||||
|
||||
|
||||
# for main loop user's display and training's exit condition
|
||||
avgNeuronsFiringRate = 0.0
|
||||
@@ -90,6 +126,25 @@ function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
end
|
||||
kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# wrap up learning session
|
||||
if kfn.learningStage == "end_learning"
|
||||
#TODO neuroplasticity
|
||||
resetLearningParams!(n)
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "inference"
|
||||
end
|
||||
end
|
||||
|
||||
""" passthrough_neuron learn()
|
||||
@@ -100,71 +155,23 @@ end
|
||||
|
||||
""" lif learn()
|
||||
"""
|
||||
function learn!(n::lif_neuron, kfn::knowledgeFn)
|
||||
if n.learnable_flag == true
|
||||
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
end
|
||||
function learn!(n::lif_neuron, error::Number)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
|
||||
ΔwRecChange = n.eta * error
|
||||
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||
|
||||
# a piece of knowledgeFn error that belongs to this neuron
|
||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
||||
n.learningStage = kfn.learningStage
|
||||
|
||||
# accumulate voltage regularization terms
|
||||
Snn_utils.cal_v_reg!(n)
|
||||
|
||||
if n.learningStage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learningStage == "start_learning" ||
|
||||
n.learningStage == "start_learning_no_wchange_reset"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "end_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
|
||||
not_zero = (!iszero).(n.w_rec)
|
||||
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
|
||||
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
|
||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
||||
|
||||
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
|
||||
end
|
||||
# check for fliped sign, 1 indicates non-fliped sign
|
||||
wSign = sign.(n.wRecChange)
|
||||
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
end
|
||||
|
||||
""" alif_neuron learn()
|
||||
"""
|
||||
function learn!(n::alif_neuron, kfn::knowledgeFn)
|
||||
function learn!(n::alif_neuron, error::Number)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.epsilonRecA = (n.phi * n.epsilonRec) +
|
||||
@@ -173,117 +180,29 @@ function learn!(n::alif_neuron, kfn::knowledgeFn)
|
||||
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
|
||||
n.eRec = n.eRec_v + n.eRec_a
|
||||
|
||||
# a piece of knowledgeFn error that belongs to this neuron
|
||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
||||
n.learningStage = kfn.learningStage
|
||||
ΔwRecChange = n.eta * error
|
||||
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||
|
||||
|
||||
|
||||
if n.learningStage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learningStage == "start_learning" ||
|
||||
n.learningStage == "start_learning_no_wchange_reset"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learningStage == "end_learning"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.wRecChange = n.wRecChange +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
|
||||
not_zero = (!iszero).(n.w_rec)
|
||||
# set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
|
||||
n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
|
||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
||||
|
||||
Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
|
||||
end
|
||||
# check for fliped sign, 1 indicates non-fliped sign
|
||||
wSign = sign.(n.wRecChange)
|
||||
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
end
|
||||
|
||||
""" linear_neuron learn()
|
||||
"""
|
||||
function learn!(n::linear_neuron, kfn::knowledgeFn)
|
||||
n.error = kfn.outputError[n.id]
|
||||
n.learningStage = kfn.learningStage
|
||||
function learn!(n::linear_neuron, error::Number)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
|
||||
ΔwRecChange = n.eta * error
|
||||
n.wRecChange = (n.subExInType * n.wRecChange) + ΔwRecChange
|
||||
|
||||
if n.learningStage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learningStage == "start_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.outputNeuronsArray[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
elseif n.learningStage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.outputNeuronsArray[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
elseif n.learningStage == "end_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.outputNeuronsArray[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
|
||||
n.w_out = n.w_out + n.w_out_change
|
||||
n.b = n.b + n.b_change
|
||||
end
|
||||
# check for fliped sign, 1 indicates non-fliped sign
|
||||
wSign = sign.(n.wRecChange)
|
||||
nonFlipedSign = isequal.(n.subExInType, wSign) # 1 not fliped, 0 fliped
|
||||
n.wRecChange .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
end
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user