module learn

using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types, ..snn_utils

export learn!

#------------------------------------------------------------------------------------------------100

"""
    learn!(m::model, modelRespond, correctAnswer=nothing)

Run one learning step on the knowledge function stored at `m.knowledgeFn[:I]`.

# Arguments
- `m`: the model whose knowledge function `:I` is trained.
- `modelRespond`: only its `length` is used, to size the default target.
- `correctAnswer`: target forwarded to the knowledge function. When it is `nothing`, a
  vector of `length(modelRespond)` zeros is used instead.
"""
function learn!(m::model, modelRespond, correctAnswer=nothing)
    correctAnswer_I = if correctAnswer === nothing
        zeros(length(modelRespond))
    else
        correctAnswer  # correct answer for kfn I
    end
    learn!(m.knowledgeFn[:I], correctAnswer_I)
end

"""
    learn!(kfn::kfn_1, correctAnswer::AbstractVector)

Compare every output neuron's `z_t1` with the matching entry of `correctAnswer` and, for
each mismatch, call `learn!` on every neuron in `kfn.neuronsArray` and on the mismatching
output neuron with that output's error.

When `kfn.learningStage == "end_learning"`, the accumulated `wRecChange` of every
`computeNeuron` in `kfn.neuronsArray` and of every output neuron is merged into `wRec`,
and `kfn.learningStage` is set to `"inference"`.

`correctAnswer` is indexed by output-neuron position, so it needs at least as many
entries as `kfn.outputNeuronsArray`.
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
    # Snapshot the outputs before any neuron is updated.
    outs = [n.z_t1 for n in kfn.outputNeuronsArray]

    for (i, out) in enumerate(outs)
        if out != correctAnswer[i]  # need to adjust weight
            outNeuron = kfn.outputNeuronsArray[i]
            # Error of this output, expressed as a percentage of its threshold.
            kfnError = (outNeuron.v_th - outNeuron.vError) * 100 / outNeuron.v_th

            # NOTE(review): the lif/alif/linear `learn!` methods call `reset_epsilonRec!`
            # at the end, so when several outputs mismatch in the same call the later
            # ones presumably see an already-reset trace -- confirm this is intended.
            # Sequential on purpose: a `Threads.@threads` variant was left disabled here.
            for n in kfn.neuronsArray
                learn!(n, kfnError)
            end
            learn!(outNeuron, kfnError)
        end
    end

    # wrap up learning session
    if kfn.learningStage == "end_learning"
        # Sequential on purpose: a `Threads.@threads` variant was left disabled here.
        for n in kfn.neuronsArray
            if n isa computeNeuron
                _mergeWeightChange!(n, kfn)
            end
        end
        for n in kfn.outputNeuronsArray
            _mergeWeightChange!(n, kfn)
        end
        kfn.learningStage = "inference"
    end
end

# Merge `n.wRecChange` into `n.wRec`, zero every weight whose sign flipped, then update
# the neuron's connection strength and wiring. Shared by hidden and output neurons.
function _mergeWeightChange!(n, kfn::kfn_1)
    signBefore = sign.(n.wRec)  # original sign
    n.wRec += n.wRecChange      # merge wRecChange into wRec
    signAfter = sign.(n.wRec)
    keptSign = isequal.(signBefore, signAfter)  # 1 sign kept, 0 sign flipped

    # Per the original author: normalize the wRec peak so the input signal does not
    # overwhelm the neuron. Must run before the flipped weights are masked out.
    normalizePeak!(n.wRec, n.wRecChange, 2)

    # Weights that flipped sign are set to 0 so they are free for a random new connection.
    n.wRec .*= keptSign

    synapticConnStrength!(n)
    neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
    return nothing
end

"""
    learn!(n::passthroughNeuron, error::Number)

Do nothing: a `passthroughNeuron` is skipped during learning.
"""
function learn!(n::passthroughNeuron, error::Number)
    # skip
end

"""
    learn!(n::lifNeuron, error::Number)

Set `n.eRec = n.phi * n.epsilonRec`, add `n.eta * error * n.eRec` to `n.wRecChange`, then
call `reset_epsilonRec!(n)`.
"""
function learn!(n::lifNeuron, error::Number)
    n.eRec = n.phi * n.epsilonRec
    n.wRecChange .+= n.eta * error * n.eRec
    reset_epsilonRec!(n)
end

"""
    learn!(n::alifNeuron, error::Number)

Same update as for a `lifNeuron`, except that `n.eRec` is the sum of two terms:
`n.eRec_v = n.phi * n.epsilonRec` and `n.eRec_a = -n.phi * n.beta * n.epsilonRecA`.
Afterwards both `reset_epsilonRec!(n)` and `reset_epsilonRecA!(n)` are called.
"""
function learn!(n::alifNeuron, error::Number)
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a
    n.wRecChange .+= n.eta * error * n.eRec
    reset_epsilonRec!(n)
    reset_epsilonRecA!(n)
end

"""
    learn!(n::linearNeuron, error::Number)

Set `n.eRec = n.phi * n.epsilonRec`, add `n.eta * error * n.eRec` to `n.wRecChange`, then
call `reset_epsilonRec!(n)`.
"""
function learn!(n::linearNeuron, error::Number)
    n.eRec = n.phi * n.epsilonRec
    n.wRecChange .+= n.eta * error * n.eRec
    reset_epsilonRec!(n)
end

end # module end