179 lines
5.8 KiB
Julia
179 lines
5.8 KiB
Julia
module learn
|
|
|
|
using Flux.Optimise: apply!
|
|
|
|
using Statistics, Flux, Random, LinearAlgebra
|
|
using GeneralUtils
|
|
using ..types, ..snn_utils
|
|
|
|
export learn!
|
|
|
|
#------------------------------------------------------------------------------------------------100
|
|
|
|
function learn!(m::model, modelRespond, correctAnswer=nothing)
    # Model-level learning step: hand the model's current learning stage to knowledge
    # function :I, then let that knowledge function learn against the target answer.
    kfnI = m.knowledgeFn[:I]
    kfnI.learningStage = m.learningStage

    # With no answer supplied the target is one zero per element of modelRespond;
    # otherwise the caller's answer is forwarded unchanged (correct answer for kfn I).
    target = isnothing(correctAnswer) ? zeros(length(modelRespond)) : correctAnswer

    # The result of the knowledge-function method is passed straight back to the caller.
    return learn!(kfnI, target)
end
|
|
|
|
""" knowledgeFn learn()
|
|
"""
|
|
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
    # One learning step of a knowledge function, driven by kfn.learningStage:
    #   "start_learning" -> reset every neuron's learning parameters, clear the
    #                       fired-neuron records, then continue as "learning"
    #   every call       -> compare each output neuron's z_t1 with correctAnswer[i] and,
    #                       on a mismatch, let all neurons learn from that output's error
    #   "end_learning"   -> merge the accumulated weight changes into the weights, clear
    #                       the fired-neuron records, then switch to "inference"
    # Always returns nothing.
    #
    # Design notes (numbering kept as written by the author):
    #          ΔWeight  Conn. Strength
    # case 1   no       no   during input signal, no correct answer available, no answer
    # case 2   no       -    during input signal, no correct answer available, wrong answer
    # case 3   +        -    during input signal, correct answer available, no answer
    # case 4   no       -    during input signal, correct answer available, wrong answer
    # case 5   no       ++   during input signal, correct answer
    # case 6   no       ++   after input signal, at correct timing, correct answer
    # case 6   +        -    after input signal, at correct timing, no answer
    # case 9   no       --   after input signal, at correct timing, wrong answer
    # case 7   adjust   +    after input signal, after correct timing (late), correct answer
    # case 8                 after input signal, after correct timing (late), no answer
    # case 8   no       -    after input signal, after correct timing (late), wrong answer

    if kfn.learningStage == "start_learning"
        # Reset here instead of at "end_learning" so the neurons' parameter data is not
        # wiped and can still be logged for visualization later.
        for n in kfn.neuronsArray
            # epsilonRec counts how often each synapse fired; that count is used to decide
            # how much the synaptic weight is adjusted, so it has to start fresh
            resetLearningParams!(n)
        end

        _clearFiredNeurons!(kfn)
        kfn.learningStage = "learning"
    end

    # compute kfn error
    # z_t1 of every output neuron is captured before any learning call runs
    outs = [n.z_t1 for n in kfn.outputNeuronsArray]
    for (i, out) in enumerate(outs)
        out == correctAnswer[i] && continue     # this output needs no weight adjustment

        outNeuron = kfn.outputNeuronsArray[i]
        # gap between v_th and v_t, expressed as a percentage of v_th
        kfnError = (outNeuron.v_th - outNeuron.v_t) * 100 / outNeuron.v_th

        # candidate for Threads.@threads once the per-neuron updates are independent
        for n in kfn.neuronsArray
            learn!(n, kfnError)
        end

        learn!(outNeuron, kfn)
    end

    # wrap up learning session
    if kfn.learningStage == "end_learning"
        for n in kfn.neuronsArray
            _mergeWeightChange!(n)
            # TODO synapticConnStrength (work in progress)
            # TODO neuroplasticity
        end

        for n in kfn.outputNeuronsArray
            _mergeWeightChange!(n)
            # TODO synapticConnStrength
            # TODO neuroplasticity
        end

        # Learning parameters are intentionally not reset here: the reset happens at
        # "start_learning" so the data stays available for logging after the session.
        _clearFiredNeurons!(kfn)
        kfn.learningStage = "inference"
    end

    return nothing
end

# Merge n.wRecChange into n.wRec, L1-normalize the merged weights, and zero every weight
# whose sign no longer equals the matching entry of n.subExInType.
function _mergeWeightChange!(n)
    n.wRec += n.wRecChange                         # merge wRecChange into wRec
    wSign = sign.(n.wRec)                          # signs taken before normalization
    keepsSign = isequal.(n.subExInType, wSign)     # true = sign kept, false = flipped
    LinearAlgebra.normalize!(n.wRec, 1)
    n.wRec .*= keepsSign    # flipped weights become 0, freeing them for a random new connection
    return nothing
end

# Replace the three fired-neuron records of kfn with fresh empty vectors.
function _clearFiredNeurons!(kfn)
    kfn.firedNeurons = Vector{Int64}()
    kfn.firedNeurons_t0 = Vector{Bool}()
    kfn.firedNeurons_t1 = Vector{Bool}()
    return nothing
end
|
|
|
|
""" passthrough_neuron learn()
|
|
"""
|
|
function learn!(n::passthrough_neuron, kfn::knowledgeFn)
    # Deliberate no-op: this method exists so a passthrough neuron can be handed to
    # learn! like any other neuron; neither argument is read or modified.
    return nothing
end
|
|
|
|
""" lif learn()
|
|
"""
|
|
function learn!(n::lif_neuron, error::Number)
    # Learning step of a lif neuron for a scalar error value.

    # refresh eRec from phi and epsilonRec
    # NOTE(review): eRec is not read again in this method — confirm whether the weight
    # update below is meant to use it.
    n.eRec = n.phi * n.epsilonRec

    # new wRecChange = (subExInType * previous wRecChange) + eta-scaled error
    scaledError = n.eta * error
    carried = n.subExInType * n.wRecChange
    n.wRecChange = carried + scaledError

    # hand back whatever reset_epsilonRec! returns
    return reset_epsilonRec!(n)
end
|
|
|
|
""" alif_neuron learn()
|
|
"""
|
|
function learn!(n::alif_neuron, error::Number)
    # Learning step of an alif neuron for a scalar error value.

    # eRec is the sum of two parts: one built from epsilonRec, one from epsilonRecA
    # NOTE(review): eRec is not read again in this method — confirm whether the weight
    # update below is meant to use it.
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = (-n.phi) * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a

    # new wRecChange = (subExInType * previous wRecChange) + eta-scaled error
    scaledError = n.eta * error
    carried = n.subExInType * n.wRecChange
    n.wRecChange = carried + scaledError

    # hand back whatever reset_epsilonRec! returns
    return reset_epsilonRec!(n)
end
|
|
|
|
""" linear_neuron learn()
|
|
"""
|
|
function learn!(n::linear_neuron, error::Number)
    # Learning step of a linear neuron for a scalar error value.

    # refresh eRec from phi and epsilonRec
    # NOTE(review): eRec is not read again in this method — confirm whether the weight
    # update below is meant to use it.
    n.eRec = n.phi * n.epsilonRec

    # new wRecChange = (subExInType * previous wRecChange) + eta-scaled error
    scaledError = n.eta * error
    carried = n.subExInType * n.wRecChange
    n.wRecChange = carried + scaledError

    # hand back whatever reset_epsilonRec! returns
    return reset_epsilonRec!(n)
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module end |