Files
Ironpen/src/learn.jl
2023-05-19 14:54:12 +07:00

130 lines
3.6 KiB
Julia

module learn
using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types, ..snn_utils
export learn!
#------------------------------------------------------------------------------------------------100
"""
    learn!(m::model, modelRespond, correctAnswer=nothing)

Run one learning step on the `:I` knowledge function of `m`.

When `correctAnswer` is `nothing`, an all-zero target with the same length as
`modelRespond` is used; otherwise `correctAnswer` is forwarded unchanged. Returns
whatever the knowledge-function `learn!` method returns.
"""
function learn!(m::model, modelRespond, correctAnswer=nothing)
    # target vector for kfn I
    target = isnothing(correctAnswer) ? zeros(length(modelRespond)) : correctAnswer
    return learn!(m.knowledgeFn[:I], target)
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
# compute kfn error
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs)
if out != correctAnswer[i] # need to adjust weight
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
100 / kfn.outputNeuronsArray[i].v_th
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
learn!(n, kfnError)
end
learn!(kfn.outputNeuronsArray[i], kfnError)
end
end
# wrap up learning session
if kfn.learningStage == "end_learning"
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
if typeof(n) <: computeNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange # merge wRecChange into wRec
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
synapticConnStrength!(n)
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
end
end
for n in kfn.outputNeuronsArray # merge wRecChange into wRec
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
normalizePeak!(n.wRec, n.wRecChange, 2)
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
synapticConnStrength!(n)
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
end
kfn.learningStage = "inference"
end
end
""" passthroughNeuron learn()
"""
function learn!(n::passthroughNeuron, error::Number)
# skip
end
""" lif learn()
"""
function learn!(n::lifNeuron, error::Number)
n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n)
end
""" alifNeuron learn()
"""
function learn!(n::alifNeuron, error::Number)
n.eRec_v = n.phi * n.epsilonRec
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
n.eRec = n.eRec_v + n.eRec_a
ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n)
reset_epsilonRecA!(n)
end
""" linearNeuron learn()
"""
function learn!(n::linearNeuron, error::Number)
n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n)
end
end # module end