module learn
using Statistics, Random, LinearAlgebra, JSON3, Flux
using GeneralUtils
using ..types, ..snn_utils
export learn!, compute_wRecChange!, computeModelError
#------------------------------------------------------------------------------------------------100
# Entry point: delegate the learning pass to the model's :I knowledge function.
learn!(m::model) = learn!(m.knowledgeFn[:I])
""" knowledgeFn learn()
|
|
"""
|
|
function learn!(kfn::kfn_1)
|
|
# compute kfn error for each neuron
|
|
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
|
# for n in kfn.neuronsArray
|
|
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
|
end
|
|
for n in kfn.outputNeuronsArray
|
|
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
|
end
|
|
|
|
# wrap up learning session
|
|
if kfn.learningStage == "end_learning"
|
|
kfn.learningStage = "inference"
|
|
end
|
|
end
|
|
|
|
"""
    computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)

Compute the model error as the logit cross-entropy between `modelRespond`
(logits) and `correctAnswer`, scaled by `magnitude`.

`correctAnswer === nothing` is treated as an all-false target of the same
length as `modelRespond`; otherwise it is converted element-wise to `Bool`.
"""
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
    if correctAnswer === nothing
        # No answer supplied: all-false target. `falses(n)` builds the BitArray
        # directly instead of allocating a Float64 vector and converting it
        # (was `BitArray(zeros(length(modelRespond)))`).
        correctAnswer = falses(length(modelRespond))
    else
        correctAnswer = Bool.(correctAnswer) # correct answer for kfn I
    end
    # logitcrossentropy returns a scalar mean loss by default, so plain `*`
    # suffices (was a needless broadcast `.*`).
    return Flux.logitcrossentropy(modelRespond, correctAnswer) * magnitude
end
# Entry point: forward the scalar error to the model's :I knowledge function.
compute_wRecChange!(m::model, error::Float64) = compute_wRecChange!(m.knowledgeFn[:I], error)
"""
    compute_wRecChange!(kfn::kfn_1, error::Float64)

Propagate the scalar model `error` into every neuron's pending recurrent
weight change (hidden neurons in parallel, then output neurons serially).
"""
function compute_wRecChange!(kfn::kfn_1, error::Float64)
    # Hidden neurons: each update only touches its own neuron, so parallelize.
    Threads.@threads for neuron in kfn.neuronsArray
        compute_wRecChange!(neuron, error)
    end
    # Output neurons.
    for neuron in kfn.outputNeuronsArray
        compute_wRecChange!(neuron, error)
    end
end
# Pass-through neurons have no learnable recurrent weights: nothing to update.
compute_wRecChange!(n::passthroughNeuron, error::Float64) = nothing
"""
    compute_wRecChange!(n::lifNeuron, error::Float64)

Accumulate a gradient-descent step `-η · error · eRec` into the neuron's
pending recurrent weight change, where the eligibility trace `eRec` is the
pseudo-derivative `phi` times the recurrent eligibility vector.
"""
function compute_wRecChange!(n::lifNeuron, error::Float64)
    n.eRec = n.phi * n.epsilonRec
    # Learning-rate-scaled descent step, folded straight into the accumulator.
    n.wRecChange .+= -n.eta * error * n.eRec
    # The trace is consumed once per update.
    reset_epsilonRec!(n)
end
"""
    compute_wRecChange!(n::alifNeuron, error::Float64)

Accumulate a gradient-descent step into the neuron's pending recurrent
weight change. The eligibility trace for adaptive LIF neurons combines a
voltage component and a `beta`-weighted adaptation component.
"""
function compute_wRecChange!(n::alifNeuron, error::Float64)
    # Voltage and adaptation components of the eligibility trace.
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a
    # Learning-rate-scaled descent step, folded straight into the accumulator.
    n.wRecChange .+= -n.eta * error * n.eRec
    # Both traces are consumed once per update.
    reset_epsilonRec!(n)
    reset_epsilonRecA!(n)
end
"""
    compute_wRecChange!(n::linearNeuron, error::Float64)

Accumulate a gradient-descent step `-η · error · eRec` into the neuron's
pending recurrent weight change (same trace form as the LIF neuron).
"""
function compute_wRecChange!(n::linearNeuron, error::Float64)
    n.eRec = n.phi * n.epsilonRec
    # Learning-rate-scaled descent step, folded straight into the accumulator.
    n.wRecChange .+= -n.eta * error * n.eRec
    # The trace is consumed once per update.
    reset_epsilonRec!(n)
end
# Input neurons have no learnable recurrent weights: nothing to do.
learn!(n::T, firedNeurons, nExInType) where {T<:inputNeuron} = nothing
"""
    learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron

Apply the accumulated recurrent weight change to a compute neuron, then
re-normalize, cap, and run the structural-plasticity updates.
"""
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
    # Merge wRecChange into wRec, then clear the accumulator.
    n.wRec += n.wRecChange
    reset_wRecChange!(n)

    # Normalize the wRec peak to prevent the input signal overwhelming the neuron.
    normalizePeak!(n.wRec, n.wRecChange, 2)
    capMaxWeight!(n.wRec) # cap maximum weight

    # NOTE(review): the sign-flip experiment (zeroing weights whose sign
    # flipped, `n.wRec .*= nonFlipedSign`) was commented out, leaving the
    # wSign_0/wSign_1/nonFlipedSign computations dead; that wasted work has
    # been removed. Restore both halves together if the experiment returns.

    synapticConnStrength!(n)
    neuroplasticity!(n, firedNeurons, nExInType)
end
"""
    learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron

Apply the accumulated recurrent weight change to an output neuron, then
re-normalize, cap, and run the neuroplasticity update (which for output
neurons also receives the total-input-port count).
"""
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
    # Merge wRecChange into wRec, then clear the accumulator.
    n.wRec += n.wRecChange
    reset_wRecChange!(n)

    # Normalize the wRec peak to prevent the input signal overwhelming the neuron.
    normalizePeak!(n.wRec, n.wRecChange, 2)
    capMaxWeight!(n.wRec) # cap maximum weight

    # NOTE(review): the sign-flip experiment (`n.wRec .*= nonFlipedSign`) and
    # synapticConnStrength! were both disabled for output neurons; the dead
    # wSign_0/wSign_1/nonFlipedSign computations have been removed. Restore
    # both halves together if the experiment returns.

    neuroplasticity!(n, firedNeurons, nExInType, totalInputPort)
end
end # module end |