# learn.jl — `learn!` methods implementing the learning rules for the
# spiking-network model (knowledge functions and individual neuron types).
module learn
|
|
|
|
using Flux.Optimise: apply!
|
|
|
|
using Statistics, Flux, Random, LinearAlgebra
|
|
using GeneralUtils
|
|
using ..types, ..snn_utils
|
|
|
|
export learn!
|
|
|
|
#------------------------------------------------------------------------------------------------100
|
|
|
|
"""
    learn!(m::model, modelRespond, correctAnswer=nothing)

Run one learning step on model `m`'s `:I` knowledge function. When no
`correctAnswer` is supplied, a zero vector the length of `modelRespond` is
used as the target (i.e. "no neuron should have fired").
"""
function learn!(m::model, modelRespond, correctAnswer=nothing)
    # Target for kfn :I — fall back to an all-zeros answer when none is given.
    target_I = correctAnswer === nothing ? zeros(length(modelRespond)) : correctAnswer
    learn!(m.knowledgeFn[:I], target_I)
end
|
|
|
|
"""
    learn!(kfn::kfn_1, correctAnswer::AbstractVector)

Apply one learning step to knowledge function `kfn`.

For every output neuron whose last spike state (`z_t1`) disagrees with the
corresponding entry of `correctAnswer`, an error signal (a percentage of the
neuron's threshold voltage) is propagated to all neurons in the network via
`learn!(neuron, error)`.

When `kfn.learningStage == "end_learning"`, the accumulated weight changes
are committed (see [`commitWeightChange!`](@ref)), synaptic connection
strengths and neuroplasticity are updated, and the stage is reset to
`"inference"`.
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
    # Compare each output neuron's last spike against the target and
    # back-propagate an error signal wherever they disagree.
    outs = [n.z_t1 for n in kfn.outputNeuronsArray]
    for (i, out) in enumerate(outs)
        if out != correctAnswer[i] # need to adjust weight
            outNeuron = kfn.outputNeuronsArray[i]
            # Error expressed as a percentage of the threshold voltage.
            kfnError = (outNeuron.v_th - outNeuron.vError) * 100 / outNeuron.v_th

            # NOTE(review): Threads.@threads was deliberately disabled in the
            # original — kept serial here; learn! mutates shared neuron state.
            for n in kfn.neuronsArray
                learn!(n, kfnError)
            end

            learn!(outNeuron, kfnError)
        end
    end

    # Wrap up the learning session: commit weight changes and update plasticity.
    if kfn.learningStage == "end_learning"
        for n in kfn.neuronsArray
            if typeof(n) <: computeNeuron
                commitWeightChange!(n)
                synapticConnStrength!(n)
                neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
            end
        end

        for n in kfn.outputNeuronsArray
            commitWeightChange!(n)
            synapticConnStrength!(n)
            neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
        end

        kfn.learningStage = "inference"
    end
end

"""
    commitWeightChange!(n)

Merge the accumulated `n.wRecChange` into `n.wRec`, normalize the weight
peak, and zero out any weight whose sign flipped during the merge (freeing
it for a new random connection). Previously this sequence was duplicated
for compute neurons and output neurons.
"""
function commitWeightChange!(n)
    wSign_0 = sign.(n.wRec)                    # original sign
    n.wRec += n.wRecChange                     # merge wRecChange into wRec
    wSign_1 = sign.(n.wRec)                    # sign after merge
    nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 = not flipped, 0 = flipped
    # Normalize wRec peak to prevent input signal overwhelming the neuron.
    normalizePeak!(n.wRec, n.wRecChange, 2)
    # Set weights that flipped sign to 0 so they can be re-drawn as new connections.
    n.wRec .*= nonFlipedSign
    return n
end
|
|
|
|
"""
    learn!(n::passthroughNeuron, error::Number)

Passthrough neurons carry no trainable weights, so learning is a no-op.
"""
learn!(n::passthroughNeuron, error::Number) = nothing
|
|
|
|
"""
    learn!(n::lifNeuron, error::Number)

Update a LIF neuron's eligibility trace and accumulate the resulting weight
change, then reset the recurrent eligibility state.
"""
function learn!(n::lifNeuron, error::Number)
    # Eligibility trace: pseudo-derivative times the recurrent filter state.
    n.eRec = n.phi * n.epsilonRec
    # Accumulate the learning-rate-scaled weight delta into wRecChange.
    n.wRecChange .+= n.eta * error * n.eRec
    reset_epsilonRec!(n)
end
|
|
|
|
"""
    learn!(n::alifNeuron, error::Number)

Update an adaptive LIF neuron's eligibility traces (voltage and adaptation
components) and accumulate the resulting weight change, then reset both
recurrent eligibility states.
"""
function learn!(n::alifNeuron, error::Number)
    # Voltage component of the eligibility trace.
    n.eRec_v = n.phi * n.epsilonRec
    # Adaptation component — opposes the voltage term, scaled by beta.
    n.eRec_a = -n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a
    # Accumulate the learning-rate-scaled weight delta into wRecChange.
    n.wRecChange .+= n.eta * error * n.eRec
    reset_epsilonRec!(n)
    reset_epsilonRecA!(n)
end
|
|
|
|
"""
    learn!(n::linearNeuron, error::Number)

Update a linear neuron's eligibility trace and accumulate the resulting
weight change, then reset the recurrent eligibility state. Mirrors the
`lifNeuron` learning rule.
"""
function learn!(n::linearNeuron, error::Number)
    # Eligibility trace: pseudo-derivative times the recurrent filter state.
    n.eRec = n.phi * n.epsilonRec
    # Accumulate the learning-rate-scaled weight delta into wRecChange.
    n.wRecChange .+= n.eta * error * n.eRec
    reset_epsilonRec!(n)
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module end |