Files
Ironpen/src/learn.jl
2023-05-12 19:50:02 +07:00

334 lines
11 KiB
Julia

module learn
using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types
export learn!
#------------------------------------------------------------------------------------------------100
"""
    learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)

Run one learning step for model `m` and return the model error.

The `"start_learning"` / `"end_learning"` stage of `m` is pushed down to the `:I`
knowledge function (other stages are left untouched on it). The error between
`modelRespond` and `correctAnswer` is then computed and passed to that knowledge
function's `learn!`.

# Arguments
- `modelRespond`: raw model output, fed to `Flux.logitcrossentropy` as logits.
- `correctAnswer`: target for `modelRespond`, or `nothing` when no target is available.
  In that case the knowledge function still runs its `learn!`, but without an error
  signal.
- `correctTiming`: currently unused; timing-based error is not implemented yet.

# Returns
The logit cross-entropy as a `Float64`, or `nothing` when `correctAnswer === nothing`.
"""
function learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)
    kfn = m.knowledgeFn[:I]

    # Only the stage transitions are propagated to the knowledge function.
    if m.learningStage in ("start_learning", "end_learning")
        kfn.learningStage = m.learningStage
    end

    if correctAnswer === nothing
        # No target: neurons still update their traces but receive no error signal.
        learn!(kfn, nothing, nothing)
        return nothing
    end

    # TODO: incorporate a timing error (e.g. based on m.modelParams[:perfect_timing] and
    # m.timeStep, together with `correctTiming`). The previous draft computed one and then
    # immediately discarded it.

    # Convert to Float64 so Float32 model outputs still match the knowledge function's
    # `error` argument type.
    modelError = Float64(Flux.logitcrossentropy(modelRespond, correctAnswer))
    outputElementsError = modelRespond - correctAnswer
    learn!(kfn, modelError, outputElementsError)
    return modelError
end
# function learn!(m::model, raw_model_respond, correct_answer=nothing)
# if m.learningStage != "doing_inference"
# model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
# output_elements_error = raw_model_respond - correct_answer
# learn!(m.knowledgeFn[:I], model_error, output_elements_error)
# else
# model_error = nothing
# end
# return model_error
# end
"""
    learn!(kfn::knowledgeFn, error=nothing, outputError=nothing)

Run one learning step over the neurons of knowledge function `kfn`.

The learning stage is read from `kfn.learningStage`, which the caller sets beforehand.
On `"start_learning"` every neuron in `kfn.neuronsArray` has its learning parameters
reset, and `kfn.firedNeurons` / `kfn.outputs` are cleared before learning. The reset
happens here rather than at `"end_learning"` so the previous values remain available
for logging/visualization.

# Arguments
- `error`: scalar error stored in `kfn.error` (compute neurons scale it by their own
  `Bn`), or `nothing` when there is no error signal.
- `outputError`: per-output-neuron error stored in `kfn.outputError` and indexed by
  output neuron `id`, or `nothing`. Output neurons only learn, and
  `kfn.avgNeuronsFiringRate` / `kfn.avgNeurons_v_t1` are only refreshed, when this is
  not `nothing`.
"""
function learn!(kfn::knowledgeFn, error::Union{Real,Nothing}=nothing,
                outputError::Union{Vector,Nothing}=nothing)
    kfn.error = error
    kfn.outputError = outputError

    if kfn.learningStage == "start_learning"
        for n in kfn.neuronsArray
            # epsilonRec counts synaptic activity that drives the size of the weight
            # update, so it must start from scratch for every learning episode.
            reset_learning_params!(n)
        end
        kfn.firedNeurons = Int64[]
        kfn.outputs = nothing
    end

    # Neurons always learn, independent of whether a model-output error exists.
    for n in kfn.neuronsArray
        learn!(n, kfn)
    end

    if kfn.outputError !== nothing
        # Sequential on purpose: the 1st output neuron sets the learning rate that the
        # other output neurons reuse, so this loop must not be multithreaded.
        for n in kfn.outputNeuronsArray
            learn!(n, kfn)
        end
        # TODO: put other KFN to learn here

        # Summary statistics for the main loop's display and training exit condition.
        firingRateSum = 0.0
        v_t1Sum = 0.0
        for n in kfn.neuronsArray
            if n isa compute_neuron
                firingRateSum += n.firingRate
                v_t1Sum += n.v_t1
            end
        end
        nCompute = kfn.kfnParams[:compute_neuron_number]
        kfn.avgNeuronsFiringRate = firingRateSum / nCompute
        kfn.avgNeurons_v_t1 = v_t1Sum / nCompute
    end
    return nothing
end
"""
    learn!(n::passthrough_neuron, kfn::knowledgeFn)

Skip learning for passthrough neurons: this method intentionally does nothing and
returns `nothing`.
"""
learn!(n::passthrough_neuron, kfn::knowledgeFn) = nothing
"""
    learn!(n::lif_neuron, kfn::knowledgeFn)

Update the eligibility trace of LIF neuron `n` and, depending on `kfn.learningStage`,
accumulate and/or apply its recurrent weight change.

- Every stage: the trace is advanced (only when `n.learnable_flag == true`), the
  neuron's share of `kfn.error` is stored in `n.error`, and voltage-regularization
  terms are accumulated.
- `"start_learning"`, `"start_learning_no_wchange_reset"`, `"during_learning"`,
  `"end_learning"`: when `n.error !== nothing`, the weight change is accumulated into
  `n.wRecChange`.
- `"end_learning"` additionally applies `n.wRecChange` to `n.w_rec` on existing
  (non-zero) connections only, clips negative weights to zero, and runs neuroplasticity.
- `"doing_inference"` or any other stage: no weight learning.
"""
function learn!(n::lif_neuron, kfn::knowledgeFn)
    if n.learnable_flag == true
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
        n.eRec = n.phi * n.epsilonRec
    end

    # The piece of the knowledgeFn error that belongs to this neuron.
    n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
    n.learningStage = kfn.learningStage

    # NOTE(review): `Snn_utils` is not brought in by this file's visible `using` lines;
    # presumably it is re-exported via `..types` — confirm.
    Snn_utils.cal_v_reg!(n)  # accumulate voltage regularization terms

    stage = n.learningStage
    learningStages = ("start_learning", "start_learning_no_wchange_reset",
                      "during_learning", "end_learning")
    stage in learningStages || return nothing

    # Accumulate Δw only when an error signal is available.
    if n.error !== nothing
        Snn_utils.firing_rate!(n)
        Snn_utils.firing_diff!(n)
        n.wRecChange = n.wRecChange +
            -apply!(n.optimiser, n.w_rec,
                    (n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
            -Snn_utils.firing_rate_regulator!(n) +
            -Snn_utils.voltage_regulator!(n)
    end

    if stage == "end_learning"
        # Mask the update with w_rec's non-zero pattern so absent (zero) connections
        # stay absent: a hard connectivity constraint.
        not_zero = (!iszero).(n.w_rec)
        n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
        replace!(x -> x < zero(x) ? zero(x) : x, n.w_rec)  # no negative weights
        Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
    end
    return nothing
end
"""
    learn!(n::alif_neuron, kfn::knowledgeFn)

Update the eligibility traces of adaptive-LIF neuron `n` and, depending on
`kfn.learningStage`, accumulate and/or apply its recurrent weight change.

- Every stage: `n.epsilonRec`, `n.epsilonRecA`, `n.eRec_v`, `n.eRec_a` and `n.eRec` are
  advanced, and the neuron's share of `kfn.error` is stored in `n.error`.
- `"start_learning"`, `"start_learning_no_wchange_reset"`, `"during_learning"`,
  `"end_learning"`: when `n.error !== nothing`, the weight change is accumulated into
  `n.wRecChange`.
- `"end_learning"` additionally applies `n.wRecChange` to `n.w_rec` on existing
  (non-zero) connections only, clips negative weights to zero, and runs neuroplasticity.
- `"doing_inference"` or any other stage: no weight learning.
"""
function learn!(n::alif_neuron, kfn::knowledgeFn)
    # NOTE(review): unlike the lif method, traces here are not gated by
    # `n.learnable_flag`, and `cal_v_reg!` is not called — confirm this is intended.
    n.decayedEpsilonRec = n.alpha * n.epsilonRec
    n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    n.epsilonRecA = (n.phi * n.epsilonRec) +
                    ((n.rho - (n.phi * n.beta)) * n.epsilonRecA)
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a

    # The piece of the knowledgeFn error that belongs to this neuron.
    n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
    n.learningStage = kfn.learningStage

    stage = n.learningStage
    learningStages = ("start_learning", "start_learning_no_wchange_reset",
                      "during_learning", "end_learning")
    stage in learningStages || return nothing

    # Accumulate Δw only when an error signal is available.
    # NOTE(review): `Snn_utils` is not brought in by this file's visible `using` lines;
    # presumably it is re-exported via `..types` — confirm.
    if n.error !== nothing
        Snn_utils.firing_rate!(n)
        Snn_utils.firing_diff!(n)
        n.wRecChange = n.wRecChange +
            -apply!(n.optimiser, n.w_rec,
                    (n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
            -Snn_utils.firing_rate_regulator!(n) +
            -Snn_utils.voltage_regulator!(n)
    end

    if stage == "end_learning"
        # Mask the update with w_rec's non-zero pattern so absent (zero) connections
        # stay absent: a hard connectivity constraint.
        not_zero = (!iszero).(n.w_rec)
        n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
        replace!(x -> x < zero(x) ? zero(x) : x, n.w_rec)  # no negative weights
        Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
    end
    return nothing
end
"""
    learn!(n::linear_neuron, kfn::knowledgeFn)

Accumulate, and at `"end_learning"` apply, the output weight and bias changes of
linear output neuron `n`.

`n.error` is taken from `kfn.outputError[n.id]`. It is `nothing` when
`kfn.outputError === nothing`, and then no change is accumulated.

- `"start_learning"`, `"during_learning"`, `"end_learning"`: when `n.error !== nothing`,
  accumulate into `n.w_out_change` and `n.b_change`.
  - Output neuron `id == 1` steps its own optimiser and records the optimiser's `eta`.
  - Every other output neuron reuses the `eta` of `kfn.outputNeuronsArray[1]`, so
    neuron 1 must be processed first. This is why the caller does not multithread
    this loop.
- `"end_learning"` additionally adds the accumulated changes to `n.w_out` and `n.b`.
- `"doing_inference"` or any other stage: no learning.
"""
function learn!(n::linear_neuron, kfn::knowledgeFn)
    n.error = kfn.outputError === nothing ? nothing : kfn.outputError[n.id]
    n.learningStage = kfn.learningStage

    stage = n.learningStage
    # NOTE(review): unlike lif/alif, "start_learning_no_wchange_reset" is not handled
    # here, so output neurons do not learn in that stage — confirm this is intended.
    stage in ("start_learning", "during_learning", "end_learning") || return nothing

    if n.error !== nothing
        if n.id == 1
            Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
            n.eta = n.optimiser.eta
        else
            n.eta = kfn.outputNeuronsArray[1].eta
            Δw = -n.eta * n.error * n.epsilon_j
        end
        n.w_out_change = n.w_out_change + Δw
        n.b_change = n.b_change + (-n.eta * n.error)
    end

    if stage == "end_learning"
        n.w_out = n.w_out + n.w_out_change
        n.b = n.b + n.b_change
    end
    return nothing
end
end # module end