module learn

using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types

export learn!

#------------------------------------------------------------------------------------------------100

# NOTE(review): `Snn_utils` and `reset_learning_params!` are used below but are not brought into
# scope by any `using` line in this file. Presumably one of the modules above (`..types` or
# `GeneralUtils`) provides them -- confirm, otherwise an explicit import is missing.

# Learning stages in which lif/alif neurons accumulate recurrent weight changes.
const RECURRENT_ACCUMULATION_STAGES = (
    "start_learning",
    "start_learning_no_wchange_reset",
    "during_learning",
    "end_learning",
)

# Learning stages in which linear (output) neurons accumulate weight/bias changes.
# NOTE(review): unlike the recurrent neurons, "start_learning_no_wchange_reset" is not handled
# for output neurons. This mirrors the previous behaviour; confirm it is intentional.
const OUTPUT_ACCUMULATION_STAGES = ("start_learning", "during_learning", "end_learning")

"""
    learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)

Compute the model error for one step and run the learning step of knowledge function `:I`.

When `m.learningStage` is `"start_learning"` or `"end_learning"` the stage is copied to
`m.knowledgeFn[:I]`; for any other stage the knowledge function keeps its current stage.

# Arguments
- `m`: the model being trained.
- `modelRespond`: raw model output (passed to `Flux.logitcrossentropy` as the prediction).
- `correctAnswer`: target with the same shape as `modelRespond`. Required.
- `correctTiming`: currently unused; the timing-weighted error is still work in progress.

# Returns
The value of `Flux.logitcrossentropy(modelRespond, correctAnswer)`.

# Throws
- `ArgumentError` if `correctAnswer` is `nothing`.
"""
function learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)
    correctAnswer === nothing &&
        throw(ArgumentError("learn!(m::model, ...) needs `correctAnswer` to compute the error"))

    # only the start/end transitions are pushed down to the knowledge function
    if m.learningStage in ("start_learning", "end_learning")
        m.knowledgeFn[:I].learningStage = m.learningStage
    end

    # The body used to read `model_respond` / `correct_answer`, which are not defined in this
    # method (the parameters are `modelRespond` / `correctAnswer`), so every call failed with
    # an UndefVarError. The timing-weighted error that was computed and then immediately
    # overwritten has been dropped until it is actually used.
    model_error = Flux.logitcrossentropy(modelRespond, correctAnswer)
    output_elements_error = modelRespond - correctAnswer

    learn!(m.knowledgeFn[:I], model_error, output_elements_error)

    return model_error
end

"""
    learn!(kfn::knowledgeFn, error=nothing, outputError=nothing)

Run one learning step for every neuron owned by `kfn`.

Stores `error` and `outputError` on `kfn`. If `kfn.learningStage` is `"start_learning"`,
resets each neuron's learning parameters and clears `kfn.firedNeurons` / `kfn.outputs`. Then
calls `learn!(n, kfn)` for every neuron in `kfn.neuronsArray`. When `outputError` is available,
the output neurons learn as well and the average firing rate and `v_t1` of the
`compute_neuron`s are stored on `kfn`.

# Arguments
- `kfn`: knowledge function to update; its `learningStage` must already be set by the caller.
- `error`: scalar error of the model, or `nothing` when no error signal is available.
- `outputError`: per-output-neuron error indexed by neuron id, or `nothing`.

# Returns
`nothing`.
"""
function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
                outputError::Union{Vector,Nothing}=nothing)
    kfn.error = error
    kfn.outputError = outputError

    # This method used to read `m.learningStage`, but no `m` exists in this scope, so it
    # raised an UndefVarError. The stage is taken from `kfn` itself, which the model-level
    # `learn!` sets before calling here.
    # NOTE(review): assumes something else moves `kfn.learningStage` away from
    # "start_learning" on later steps; otherwise the reset below repeats every call -- confirm.
    if kfn.learningStage == "start_learning"
        # reset params here instead of at end_learning so that the neurons' parameter data
        # is not wiped and can still be logged for visualization afterwards
        for n in kfn.neuronsArray
            # epsilonRec must be reset because it counts how often each synapse fired, which
            # determines how much the synaptic weight is adjusted
            reset_learning_params!(n)
        end

        # clear variables
        kfn.firedNeurons = Vector{Int64}()
        kfn.outputs = nothing
    end

    # neurons always learn, with or without an error signal from the model output
    for n in kfn.neuronsArray
        learn!(n, kfn)
    end

    if kfn.outputError !== nothing
        # not multithreaded: the 1st output neuron sets the learning rate that the other
        # output neurons reuse
        for n in kfn.outputNeuronsArray
            learn!(n, kfn)
        end

        #TODO: put other KFN to learn here

        # averages for the main loop's user display and the training exit condition
        firingRateSum = 0.0
        vSum = 0.0
        for n in kfn.neuronsArray
            if n isa compute_neuron
                firingRateSum += n.firingRate
                vSum += n.v_t1
            end
        end
        kfn.avgNeuronsFiringRate = firingRateSum / kfn.kfnParams[:compute_neuron_number]
        kfn.avgNeurons_v_t1 = vSum / kfn.kfnParams[:compute_neuron_number]
    end

    return nothing
end

"""
    learn!(n::passthrough_neuron, kfn::knowledgeFn)

Do nothing: passthrough neurons are skipped during learning.
"""
function learn!(n::passthrough_neuron, kfn::knowledgeFn)
    return nothing
end

"""
    accumulate_wrec_change!(n)

Add one step of recurrent weight change to `n.wRecChange` when `n.error` is available.

Updates the firing-rate bookkeeping through `Snn_utils`, then subtracts the optimiser step
for `(error + voltage error + firing-rate error) * eRec` and both regulator terms.
Does nothing when `n.error` is `nothing`.
"""
function accumulate_wrec_change!(n::Union{lif_neuron,alif_neuron})
    n.error === nothing && return nothing

    Snn_utils.firing_rate!(n)
    Snn_utils.firing_diff!(n)
    n.wRecChange = n.wRecChange +
        -apply!(n.optimiser, n.w_rec,
                (n.error + Snn_utils.voltage_error!(n) + n.firingRateError) * n.eRec) +
        -Snn_utils.firing_rate_regulator!(n) +
        -Snn_utils.voltage_regulator!(n)

    return nothing
end

"""
    commit_wrec!(n, kfn::knowledgeFn)

Apply the accumulated `n.wRecChange` to `n.w_rec` and run neuroplasticity.

Entries of `n.w_rec` that are exactly zero stay zero (hard-constrained connections) and
negative weights are replaced by zero.
"""
function commit_wrec!(n::Union{lif_neuron,alif_neuron}, kfn::knowledgeFn)
    # zero entries in w_rec mask the matching entries of wRecChange
    not_zero = (!iszero).(n.w_rec)
    n.w_rec = n.w_rec + (not_zero .* n.wRecChange)
    replace!(x -> x < 0 ? zero(x) : x, n.w_rec)  # no negative weight
    Snn_utils.neuroplasticity!(n, kfn.firedNeurons)

    return nothing
end

"""
    learn!(n::lif_neuron, kfn::knowledgeFn)

Run one learning step for a LIF neuron.

When `n.learnable_flag` is set, updates the eligibility trace (`decayedEpsilonRec`,
`epsilonRec`, `eRec`). Takes this neuron's share of the error (`kfn.error * n.Bn`, or
`nothing`), copies the learning stage from `kfn` and accumulates the voltage regularization
term. In any learning stage the weight change is accumulated; at `"end_learning"` it is also
applied to `n.w_rec`. `"doing_inference"` and unknown stages change no weights.

# Returns
`nothing`.
"""
function learn!(n::lif_neuron, kfn::knowledgeFn)
    if n.learnable_flag == true
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
        n.eRec = n.phi * n.epsilonRec
    end

    # the piece of the knowledgeFn error that belongs to this neuron
    n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
    n.learningStage = kfn.learningStage

    # accumulate voltage regularization terms
    Snn_utils.cal_v_reg!(n)

    if n.learningStage in RECURRENT_ACCUMULATION_STAGES
        accumulate_wrec_change!(n)
        if n.learningStage == "end_learning"
            commit_wrec!(n, kfn)
        end
    end

    return nothing
end

"""
    learn!(n::alif_neuron, kfn::knowledgeFn)

Run one learning step for an adaptive LIF neuron.

Updates the voltage and adaptation eligibility traces (`epsilonRec`, `epsilonRecA`, `eRec_v`,
`eRec_a`, `eRec`), takes this neuron's share of the error (`kfn.error * n.Bn`, or `nothing`)
and copies the learning stage from `kfn`. In any learning stage the weight change is
accumulated; at `"end_learning"` it is also applied to `n.w_rec`. `"doing_inference"` and
unknown stages change no weights.

# Returns
`nothing`.
"""
function learn!(n::alif_neuron, kfn::knowledgeFn)
    n.decayedEpsilonRec = n.alpha * n.epsilonRec
    n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    n.epsilonRecA = (n.phi * n.epsilonRec) + ((n.rho - (n.phi * n.beta)) * n.epsilonRecA)
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a

    # the piece of the knowledgeFn error that belongs to this neuron
    n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
    n.learningStage = kfn.learningStage

    if n.learningStage in RECURRENT_ACCUMULATION_STAGES
        accumulate_wrec_change!(n)
        if n.learningStage == "end_learning"
            commit_wrec!(n, kfn)
        end
    end

    return nothing
end

"""
    accumulate_output_change!(n::linear_neuron, kfn::knowledgeFn)

Add one step of weight and bias change to an output neuron when `n.error` is available.

The neuron with `id == 1` steps its optimiser and publishes the resulting learning rate in
`n.eta`; every other output neuron reuses `kfn.outputNeuronsArray[1].eta`. This ordering
dependency is why output neurons must not learn in parallel.
"""
function accumulate_output_change!(n::linear_neuron, kfn::knowledgeFn)
    n.error === nothing && return nothing

    if n.id == 1
        Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
        n.w_out_change = n.w_out_change + Δw
        n.eta = n.optimiser.eta
    else
        # presumably outputNeuronsArray[1] is the neuron with id 1 -- verify against the
        # code that builds the array
        n.eta = kfn.outputNeuronsArray[1].eta
        Δw = -n.eta * n.error * n.epsilon_j
        n.w_out_change = n.w_out_change + Δw
    end

    Δb = -n.eta * n.error
    n.b_change = n.b_change + Δb

    return nothing
end

"""
    learn!(n::linear_neuron, kfn::knowledgeFn)

Run one learning step for a linear (output) neuron.

Reads this neuron's error from `kfn.outputError[n.id]` and copies the learning stage from
`kfn`. In `"start_learning"`, `"during_learning"` and `"end_learning"` the weight and bias
changes are accumulated; at `"end_learning"` they are also added to `n.w_out` and `n.b`.
Other stages change nothing.

`kfn.outputError` must not be `nothing`; the knowledgeFn-level `learn!` only calls this
method when an output error is available.

# Returns
`nothing`.
"""
function learn!(n::linear_neuron, kfn::knowledgeFn)
    n.error = kfn.outputError[n.id]
    n.learningStage = kfn.learningStage

    if n.learningStage in OUTPUT_ACCUMULATION_STAGES
        accumulate_output_change!(n, kfn)
        if n.learningStage == "end_learning"
            n.w_out = n.w_out + n.w_out_change
            n.b = n.b + n.b_change
        end
    end

    return nothing
end

end # module end