module learn

using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types, ..snn_utils

export learn!

#------------------------------------------------------------------------------------------------100

"""
    learn!(m::model, modelRespond, correctAnswer=nothing)

Run one learning step of model `m` by forwarding to its `:I` knowledge function.

The model's `learningStage` is first copied into `m.knowledgeFn[:I]`, so the knowledge
function follows the model's stage. When `correctAnswer` is `nothing`, a zero vector of
length `length(modelRespond)` is used as the target instead.

Returns whatever `learn!(m.knowledgeFn[:I], target)` returns.
"""
function learn!(m::model, modelRespond, correctAnswer=nothing)
    kfnI = m.knowledgeFn[:I]
    kfnI.learningStage = m.learningStage

    if isnothing(correctAnswer)
        # no supervision signal available → compare outputs against an all-zero target
        correctAnswer_I = zeros(length(modelRespond))
    else
        correctAnswer_I = correctAnswer  # correct answer for kfn I
    end

    return learn!(kfnI, correctAnswer_I)
end

"""
    learn!(kfn::kfn_1, correctAnswer::AbstractVector)

Apply one learning step to knowledge function `kfn`, driven by `kfn.learningStage`.

- `"start_learning"`: call `resetLearningParams!` on every neuron of `kfn.neuronsArray`,
  clear the fired-neuron bookkeeping and switch the stage to `"learning"`.
- Every call: for each output neuron `i` whose `z_t1` differs from `correctAnswer[i]`,
  compute `kfnError = (v_th - v_t) * 100 / v_th` from that output neuron, call
  `learn!(n, kfnError)` on every neuron of `kfn.neuronsArray`, then call
  `learn!(outputNeuron, kfn)`.
- `"end_learning"`: merge `wRecChange` into `wRec` for every hidden and output neuron
  (see `_merge_wRecChange!`), call `resetLearningParams!` on each output neuron, clear
  the fired-neuron bookkeeping and switch the stage to `"inference"`.

`correctAnswer[i]` is compared with the `i`-th output neuron, so it needs at least as
many elements as `kfn.outputNeuronsArray` (otherwise a `BoundsError` is thrown).
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
    # Design notes on the intended learning signal per situation; the code below does
    # not yet distinguish these cases (it only checks output != correct answer).
    #          ΔWeight Conn.Strength
    # case 1   no      no    during input signal, no correct answer available, no answer
    # case 2   no      -     during input signal, no correct answer available, wrong answer
    # case 3   +       -     during input signal, correct answer available, no answer
    # case 4   no      -     during input signal, correct answer available, wrong answer
    # case 5   no      ++    during input signal, correct answer
    # case 6   no      ++    after input signal, at correct timing, correct answer
    # case 6   +       -     after input signal, at correct timing, no answer
    # case 9   no      --    after input signal, at correct timing, wrong answer
    # case 7   adjust  +     after input signal, after correct timing (late), correct answer
    # case 8                 after input signal, after correct timing (late), no answer
    # case 8   no      -     after input signal, after correct timing (late), wrong answer
    # success

    if kfn.learningStage == "start_learning"
        # Reset params here instead of at end_learning so that the neurons' parameter
        # data doesn't get wiped and can still be logged for visualization afterwards.
        for n in kfn.neuronsArray
            # epsilonRec needs resetting because it counts how often each synapse fires,
            # and that count decides how much each synaptic weight is adjusted
            resetLearningParams!(n)
        end
        _clear_fired_neurons!(kfn)
        kfn.learningStage = "learning"
    end

    # compute kfn error; output spikes are snapshotted before any neuron is updated
    outs = [n.z_t1 for n in kfn.outputNeuronsArray]
    for (i, out) in enumerate(outs)
        if out != correctAnswer[i]  # wrong output → adjust weights
            outNeuron = kfn.outputNeuronsArray[i]
            # gap between v_th and v_t, expressed as a percentage of v_th
            kfnError = (outNeuron.v_th - outNeuron.v_t) * 100 / outNeuron.v_th
            for n in kfn.neuronsArray
                learn!(n, kfnError)
            end
            learn!(outNeuron, kfn)
        end
    end

    # wrap up learning session
    if kfn.learningStage == "end_learning"
        for n in kfn.neuronsArray
            _merge_wRecChange!(n)
            #WORKING synapticConnStrength
            #TODO neuroplasticity
        end

        for n in kfn.outputNeuronsArray
            _merge_wRecChange!(n)
            # NOTE(review): this reset used to sit after this loop, where `n` is not
            # defined (UndefVarError on every end_learning). It is applied per output
            # neuron here, presumably so the merged wRecChange is not merged again at
            # the next end_learning — confirm against resetLearningParams! in snn_utils.
            resetLearningParams!(n)
            #TODO synapticConnStrength
            #TODO neuroplasticity
        end

        _clear_fired_neurons!(kfn)
        kfn.learningStage = "inference"
    end
end

"""
    _clear_fired_neurons!(kfn)

Replace the fired-neuron bookkeeping vectors of `kfn` with fresh empty vectors.
"""
function _clear_fired_neurons!(kfn)
    kfn.firedNeurons = Vector{Int64}()
    kfn.firedNeurons_t0 = Vector{Bool}()
    kfn.firedNeurons_t1 = Vector{Bool}()
    return kfn
end

"""
    _merge_wRecChange!(n)

Merge the pending weight change of neuron `n` into its weights.

1. `wRec ← wRec .+ wRecChange`.
2. Record where `sign(wRec)` still equals `n.subExInType`.
3. L1-normalise `wRec`. An all-zero `wRec` is left untouched, because normalising it
   would divide by zero and fill it with NaNs.
4. Set every weight whose sign flipped to 0, freeing it for a new random connection.
"""
function _merge_wRecChange!(n)
    n.wRec = n.wRec .+ n.wRecChange
    # true where the sign is not flipped (scaling by a positive norm keeps signs)
    keepSign = isequal.(n.subExInType, sign.(n.wRec))
    iszero(n.wRec) || LinearAlgebra.normalize!(n.wRec, 1)
    n.wRec .*= keepSign
    return n
end

"""
    learn!(n::passthrough_neuron, kfn::knowledgeFn)

Passthrough neurons are skipped during learning; this method does nothing.
"""
learn!(n::passthrough_neuron, kfn::knowledgeFn) = nothing

"""
    _accumulate_wRecChange!(n, error)

Update the pending weight change of `n` as
`wRecChange ← subExInType .* wRecChange .+ eta * error`.

The update is broadcast because `wRecChange` is later added element-wise to the weight
array `wRec`, and `subExInType` is compared element-wise with `sign.(wRec)`. With
non-broadcast `*`/`+`, an array `wRecChange` plus a scalar `eta * error` is a
`MethodError`.
"""
function _accumulate_wRecChange!(n, error)
    # NOTE(review): the eligibility terms (`eRec`) computed by the callers are not used
    # here, and multiplying the *accumulated* change by `subExInType` on every call
    # re-signs it repeatedly — confirm this is the intended learning rule.
    Δ = n.eta * error
    n.wRecChange = n.subExInType .* n.wRecChange .+ Δ
    return n
end

"""
    learn!(n::lif_neuron, error::Number)

Set `n.eRec = n.phi * n.epsilonRec`, accumulate `error` into `n.wRecChange` (see
`_accumulate_wRecChange!`), then call `reset_epsilonRec!(n)`.
"""
function learn!(n::lif_neuron, error::Number)
    n.eRec = n.phi * n.epsilonRec
    _accumulate_wRecChange!(n, error)
    reset_epsilonRec!(n)
end

"""
    learn!(n::alif_neuron, error::Number)

Set `n.eRec` as the sum of a voltage term `phi * epsilonRec` and an adaptation term
`-phi * beta * epsilonRecA`. Then accumulate `error` into `n.wRecChange` (see
`_accumulate_wRecChange!`) and call `reset_epsilonRec!(n)`.
"""
function learn!(n::alif_neuron, error::Number)
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a
    _accumulate_wRecChange!(n, error)
    reset_epsilonRec!(n)
end

"""
    learn!(n::linear_neuron, error::Number)

Set `n.eRec = n.phi * n.epsilonRec`, accumulate `error` into `n.wRecChange` (see
`_accumulate_wRecChange!`), then call `reset_epsilonRec!(n)`.
"""
function learn!(n::linear_neuron, error::Number)
    n.eRec = n.phi * n.epsilonRec
    _accumulate_wRecChange!(n, error)
    reset_epsilonRec!(n)
end

end # module end