# learn.jl — learn! methods for the model, knowledgeFn, and neuron types.
module learn
|
|
|
|
using Flux.Optimise: apply!
|
|
|
|
using Statistics, Flux, Random, LinearAlgebra
|
|
using GeneralUtils
|
|
using ..types
|
|
|
|
export learn!
|
|
|
|
#------------------------------------------------------------------------------------------------100
|
|
|
|
"""
    learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)

Top-level learning entry point: propagate the model's learning stage to
knowledge function `:I`, compute the output error of `modelRespond` against
`correctAnswer`, and trigger learning in knowledge function `:I`.

Returns the scalar model error (`Flux.logitcrossentropy`).

`correctTiming` is accepted for interface compatibility; the timing-based
error term is still work in progress (kept below as commented code).
"""
function learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)

    # set all KFN learning stages from the model's stage
    if m.learningStage == "start_learning"
        m.knowledgeFn[:I].learningStage = "start_learning"
    elseif m.learningStage == "end_learning"
        m.knowledgeFn[:I].learningStage = "end_learning"
    end

    #WORKING compute timing error — not yet folded into the final error
    # too_early = m.modelParams[:perfect_timing] - m.timeStep
    # model_error = (modelRespond .- correctAnswer) * too_early

    # scalar error for display/exit condition; per-element error for credit assignment
    model_error = Flux.logitcrossentropy(modelRespond, correctAnswer)
    output_elements_error = modelRespond - correctAnswer

    learn!(m.knowledgeFn[:I], model_error, output_elements_error)

    return model_error
end
|
|
|
|
|
|
# NOTE(review): commented-out earlier variant of learn!(m::model, ...) — remove once confirmed unused.
# function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
|
# if m.learningStage != "doing_inference"
|
|
# model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
|
# output_elements_error = raw_model_respond - correct_answer
|
|
|
|
# learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
|
# else
|
|
# model_error = nothing
|
|
# end
|
|
|
|
# return model_error
|
|
# end
|
|
|
|
|
|
|
|
|
|
"""
    learn!(kfn::knowledgeFn, error=nothing, outputError=nothing)

Run one learning step over a knowledge function: store the incoming error
signals, reset per-neuron learning state at "start_learning", let every
neuron learn, and — when an output error is present — let the output
neurons learn and refresh the summary statistics used by the main loop.

`kfn.learningStage` is assumed to have been set by the caller
(`learn!(m::model, ...)`) before this is invoked.
"""
function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
                outputError::Union{Vector,Nothing}=nothing)
    kfn.error = error
    kfn.outputError = outputError

    if kfn.learningStage == "start_learning"
        # reset params here instead of at the end_learning so that neuron's parameter data
        # don't gets wiped and can be logged for visualization later
        for n in kfn.neuronsArray
            # epsilonRec need to be reset because it counting how many each synaptic fires and
            # use this info to calculate how much synaptic weight should be adjust
            reset_learning_params!(n)
        end

        # clear variables
        kfn.firedNeurons = Vector{Int64}()
        kfn.outputs = nothing
    end

    # Threads.@threads for n in kfn.neuronsArray
    for n in kfn.neuronsArray
        learn!(n, kfn) # Neurons are always learning, besides error from model output
    end

    if kfn.outputError !== nothing
        # not use multithreading because 1st output neuron will set learning rate
        # that will be used by other output neurons
        # Threads.@threads for n in kfn.outputNeuronsArray
        for n in kfn.outputNeuronsArray
            learn!(n, kfn)
        end
        #TODO: put other KFN to learn here

        # for main loop user's display and training's exit condition:
        # accumulate both summary statistics in a single pass
        avgNeuronsFiringRate = 0.0
        avgNeurons_v_t1 = 0.0
        for n in kfn.neuronsArray
            if typeof(n) <: compute_neuron
                avgNeuronsFiringRate += n.firingRate
                avgNeurons_v_t1 += n.v_t1
            end
        end
        kfn.avgNeuronsFiringRate = avgNeuronsFiringRate /
                                   kfn.kfnParams[:compute_neuron_number]
        kfn.avgNeurons_v_t1 = avgNeurons_v_t1 / kfn.kfnParams[:compute_neuron_number]
    end
end
|
|
|
|
"""
    learn!(n::passthrough_neuron, kfn::knowledgeFn)

Passthrough neurons have no learnable parameters, so learning is a no-op.
"""
function learn!(n::passthrough_neuron, kfn::knowledgeFn)
    return nothing
end
|
|
|
|
"""
    learn!(n::lif_neuron, kfn::knowledgeFn)

Per-timestep learning update for a leaky-integrate-and-fire neuron: update
the eligibility trace (when `learnable_flag` is set), take this neuron's
share of the knowledge-function error, accumulate recurrent weight changes
while learning, and apply them — with non-negativity and hard-zero
connection constraints — at "end_learning".
"""
function learn!(n::lif_neuron, kfn::knowledgeFn)
    if n.learnable_flag == true
        # eligibility trace: decayed spike history plus current spike
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
        n.eRec = n.phi * n.epsilonRec
    end

    # a piece of knowledgeFn error that belongs to this neuron
    n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
    n.learningStage = kfn.learningStage

    # accumulate voltage regularization terms
    Snn_utils.cal_v_reg!(n)

    # Shared Δw accumulation (task error + voltage error + firing-rate error,
    # minus regularizers); identical for every learning stage that learns.
    function accumulate_wrec_change!()
        Snn_utils.firing_rate!(n)
        Snn_utils.firing_diff!(n)
        n.wRecChange = n.wRecChange +
                       -apply!(n.optimiser, n.w_rec,
                           (n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
                       -Snn_utils.firing_rate_regulator!(n) +
                       -Snn_utils.voltage_regulator!(n)
        return nothing
    end

    if n.learningStage == "doing_inference"
        # no learning
    elseif n.learningStage == "start_learning" ||
           n.learningStage == "start_learning_no_wchange_reset" ||
           n.learningStage == "during_learning"
        # if error signal available then accumulates Δw
        n.error !== nothing && accumulate_wrec_change!()
    elseif n.learningStage == "end_learning"
        # if error signal available then accumulates Δw
        n.error !== nothing && accumulate_wrec_change!()

        not_zero = (!iszero).(n.w_rec)
        # set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
        n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
        replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight

        Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
    end
end
|
|
|
|
"""
    learn!(n::alif_neuron, kfn::knowledgeFn)

Per-timestep learning update for an adaptive leaky-integrate-and-fire
neuron: update the voltage and adaptation eligibility traces, take this
neuron's share of the knowledge-function error, accumulate recurrent weight
changes while learning, and apply them — with non-negativity and hard-zero
connection constraints — at "end_learning".
"""
function learn!(n::alif_neuron, kfn::knowledgeFn)
    # eligibility traces: voltage component plus adaptation component
    n.decayedEpsilonRec = n.alpha * n.epsilonRec
    n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    n.epsilonRecA = (n.phi * n.epsilonRec) +
                    ((n.rho - (n.phi * n.beta)) * n.epsilonRecA)
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a

    # a piece of knowledgeFn error that belongs to this neuron
    n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
    n.learningStage = kfn.learningStage

    # Shared Δw accumulation (task error + voltage error + firing-rate error,
    # minus regularizers); identical for every learning stage that learns.
    function accumulate_wrec_change!()
        Snn_utils.firing_rate!(n)
        Snn_utils.firing_diff!(n)
        n.wRecChange = n.wRecChange +
                       -apply!(n.optimiser, n.w_rec,
                           (n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
                       -Snn_utils.firing_rate_regulator!(n) +
                       -Snn_utils.voltage_regulator!(n)
        return nothing
    end

    if n.learningStage == "doing_inference"
        # no learning
    elseif n.learningStage == "start_learning" ||
           n.learningStage == "start_learning_no_wchange_reset" ||
           n.learningStage == "during_learning"
        # if error signal available then accumulates Δw
        n.error !== nothing && accumulate_wrec_change!()
    elseif n.learningStage == "end_learning"
        # if error signal available then accumulates Δw
        n.error !== nothing && accumulate_wrec_change!()

        not_zero = (!iszero).(n.w_rec)
        # set 0 in wRecChange update according to 0 in w_rec for hard constrain connection
        n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
        replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight

        Snn_utils.neuroplasticity!(n, kfn.firedNeurons)
    end
end
|
|
|
|
"""
    learn!(n::linear_neuron, kfn::knowledgeFn)

Per-timestep learning update for a linear output neuron: take this neuron's
element of the knowledge-function output error, accumulate weight/bias
changes while learning, and apply them at "end_learning".

The first output neuron (`n.id == 1`) runs the optimiser and publishes its
learning rate `eta`; the other output neurons reuse that rate — which is
why output-neuron learning must NOT run multithreaded.
"""
function learn!(n::linear_neuron, kfn::knowledgeFn)
    n.error = kfn.outputError[n.id]
    n.learningStage = kfn.learningStage

    # Shared Δw/Δb accumulation; identical for every learning stage that learns.
    function accumulate_changes!()
        if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
            Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
            n.w_out_change = n.w_out_change + Δw
            n.eta = n.optimiser.eta
            Δb = -n.eta * n.error
            n.b_change = n.b_change + Δb
        elseif n.error !== nothing && n.id !== 1
            n.eta = kfn.outputNeuronsArray[1].eta
            Δw = -n.eta * n.error * n.epsilon_j
            n.w_out_change = n.w_out_change + Δw
            Δb = -n.eta * n.error
            n.b_change = n.b_change + Δb
        end
        return nothing
    end

    if n.learningStage == "doing_inference"
        # no learning
    elseif n.learningStage == "start_learning" ||
           n.learningStage == "during_learning"
        # if error signal available then accumulates Δw
        accumulate_changes!()
    elseif n.learningStage == "end_learning"
        # if error signal available then accumulates Δw
        accumulate_changes!()

        # apply the accumulated changes
        n.w_out = n.w_out + n.w_out_change
        n.b = n.b + n.b_change
    end
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module end |