# Files
# Ironpen/src/learn.jl
# 2023-05-16 23:24:56 +07:00
#
# 179 lines
# 5.8 KiB
# Julia

module learn
using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types, ..snn_utils
export learn!
#------------------------------------------------------------------------------------------------100
"""
    learn!(m::model, modelRespond, correctAnswer=nothing)

Propagate the model's current learning stage to knowledge function `:I` and
run its learning step. When `correctAnswer` is `nothing`, an all-zero target
vector with the same length as `modelRespond` is used instead.
"""
function learn!(m::model, modelRespond, correctAnswer=nothing)
    m.knowledgeFn[:I].learningStage = m.learningStage
    # Fall back to a zero target when no supervision signal was supplied.
    target_I = correctAnswer === nothing ? zeros(length(modelRespond)) : correctAnswer
    learn!(m.knowledgeFn[:I], target_I)
end
"""
    learn!(kfn::kfn_1, correctAnswer::AbstractVector)

Run one learning step for knowledge function `kfn` against the target vector
`correctAnswer`: initialize a session on `"start_learning"`, adjust weights for
every output neuron whose latest spike `z_t1` mismatches its target, and on
`"end_learning"` merge the accumulated weight changes, normalize, and return
to the `"inference"` stage.
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
    # Case table (design notes; ΔWeight / Conn.Strength per condition).
    # NOTE(review): case numbering in the original draft is inconsistent
    # (6 and 8 appear twice; 9 precedes 7) — preserved verbatim:
    #   case 1  no      no   during input signal, no correct answer available, no answer
    #   case 2  no      -    during input signal, no correct answer available, wrong answer
    #   case 3  +       -    during input signal, correct answer available, no answer
    #   case 4  no      -    during input signal, correct answer available, wrong answer
    #   case 5  no      ++   during input signal, correct answer
    #   case 6  no      ++   after input signal, at correct timing, correct answer
    #   case 6  +       -    after input signal, at correct timing, no answer
    #   case 9  no      --   after input signal, at correct timing, wrong answer
    #   case 7  adjust  +    after input signal, after correct timing (late), correct answer
    #   case 8               after input signal, after correct timing (late), no answer
    #   case 8  no      -    after input signal, after correct timing (late), wrong answer
    if kfn.learningStage == "start_learning"
        # Reset params here instead of at end_learning so that neuron parameter
        # data isn't wiped before it can be logged for visualization.
        for n in kfn.neuronsArray
            # epsilonRec counts how often each synapse fires; that count drives
            # how much each synaptic weight is adjusted, so it must start at zero.
            resetLearningParams!(n)
        end
        # Clear per-session bookkeeping.
        kfn.firedNeurons = Vector{Int64}()
        kfn.firedNeurons_t0 = Vector{Bool}()
        kfn.firedNeurons_t1 = Vector{Bool}()
        kfn.learningStage = "learning"
    end
    # Compute the kfn error and adjust weights for every mismatched output.
    outs = [n.z_t1 for n in kfn.outputNeuronsArray]
    for (i, out) in enumerate(outs)
        if out != correctAnswer[i] # need to adjust weight
            # Error expressed as percentage distance of membrane potential from threshold.
            kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t) *
                       100 / kfn.outputNeuronsArray[i].v_th
            # Threads.@threads for n in kfn.neuronsArray
            for n in kfn.neuronsArray
                learn!(n, kfnError)
            end
            learn!(kfn.outputNeuronsArray[i], kfn)
        end
    end
    # Wrap up the learning session.
    if kfn.learningStage == "end_learning"
        # Threads.@threads for n in kfn.neuronsArray
        for n in kfn.neuronsArray
            n.wRec += n.wRecChange # merge wRecChange into wRec
            wSign = sign.(n.wRec) # 1 indicates a non-flipped sign
            nonFlippedSign = isequal.(n.subExInType, wSign) # 1 not flipped, 0 flipped
            LinearAlgebra.normalize!(n.wRec, 1)
            # Zero out weights whose sign flipped so a new random connection
            # can replace them later.
            n.wRec .*= nonFlippedSign
            # FIX: removed an empty inner `for n in kfn.neuronsArray` loop that
            # shadowed this loop variable and did nothing.
            #WORKING synapticConnStrength
            #TODO neuroplasticity
        end
        for n in kfn.outputNeuronsArray # merge wRecChange into wRec
            n.wRec += n.wRecChange
            wSign = sign.(n.wRec) # 1 indicates a non-flipped sign
            nonFlippedSign = isequal.(n.subExInType, wSign) # 1 not flipped, 0 flipped
            LinearAlgebra.normalize!(n.wRec, 1)
            n.wRec .*= nonFlippedSign # zero flipped-sign weights for random new connection
            #TODO synapticConnStrength
            #TODO neuroplasticity
        end
        # FIX: removed a stray `resetLearningParams!(n)` that sat outside any
        # loop — `n` was out of scope there and raised UndefVarError; per the
        # note above, the reset is intentionally done in the start_learning branch.
        # Clear per-session bookkeeping.
        kfn.firedNeurons = Vector{Int64}()
        kfn.firedNeurons_t0 = Vector{Bool}()
        kfn.firedNeurons_t1 = Vector{Bool}()
        kfn.learningStage = "inference"
    end
end
"""
    learn!(n::passthrough_neuron, kfn::knowledgeFn)

Intentional no-op: learning is skipped for passthrough neurons.
"""
learn!(n::passthrough_neuron, kfn::knowledgeFn) = nothing
"""
    learn!(n::lif_neuron, error::Number)

Accumulate a weight update for a LIF neuron from the scalar `error`.
Stores the scaled eligibility trace in `n.eRec` (not consumed here —
presumably read elsewhere; verify against callers), folds `eta * error`
into `n.wRecChange`, then resets the trace for the next window.
"""
function learn!(n::lif_neuron, error::Number)
    n.eRec = n.phi * n.epsilonRec
    Δ = n.eta * error
    n.wRecChange = n.subExInType * n.wRecChange + Δ
    reset_epsilonRec!(n)
end
"""
    learn!(n::alif_neuron, error::Number)

Accumulate a weight update for an adaptive LIF neuron from the scalar `error`.
Combines the membrane trace (`epsilonRec`) and the adaptation trace
(`epsilonRecA`, weighted by `-beta`) into `n.eRec`, folds `eta * error` into
`n.wRecChange`, then resets the traces for the next window.
"""
function learn!(n::alif_neuron, error::Number)
    # Eligibility: membrane component plus (negative) adaptation component.
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a
    Δ = n.eta * error
    n.wRecChange = n.subExInType * n.wRecChange + Δ
    reset_epsilonRec!(n)
end
"""
    learn!(n::linear_neuron, error::Number)

Accumulate a weight update for a linear neuron from the scalar `error`.
Same rule as the LIF neuron: store the scaled eligibility trace in `n.eRec`,
fold `eta * error` into `n.wRecChange`, and reset the trace.
"""
function learn!(n::linear_neuron, error::Number)
    n.eRec = n.phi * n.epsilonRec
    Δ = n.eta * error
    n.wRecChange = n.subExInType * n.wRecChange + Δ
    reset_epsilonRec!(n)
end
end # module end