Files
Ironpen/previousVersion/0.0.3/src/learn.jl
2023-06-20 10:15:47 +07:00

161 lines
3.9 KiB
Julia

module learn
using Statistics, Random, LinearAlgebra, JSON3, Flux
using GeneralUtils
using ..types, ..snn_utils
export learn!, compute_wRecChange!, computeModelError
#------------------------------------------------------------------------------------------------100
"""
    learn!(m::model)

Run one learning pass on model `m` by delegating to its `:I` knowledge function.
"""
learn!(m::model) = learn!(m.knowledgeFn[:I])

"""
    learn!(kfn::kfn_1)

Apply the weight changes accumulated in every neuron of `kfn`.

Neurons in `kfn.neuronsArray` are processed with `Threads.@threads`; neurons in
`kfn.outputNeuronsArray` are then processed one at a time and additionally receive
`kfn.kfnParams[:totalInputPort]`. If `kfn.learningStage` is `"end_learning"`, it is
switched to `"inference"` after all neurons have been processed.
"""
function learn!(kfn::kfn_1)
    # NOTE(review): an earlier comment on this loop said the multithreaded version
    # "is not atomic and causing error"; confirm the per-neuron learn! methods do not
    # mutate shared state (e.g. kfn.firedNeurons) before relying on this.
    Threads.@threads for neuron in kfn.neuronsArray
        learn!(neuron, kfn.firedNeurons, kfn.nExInType)
    end

    foreach(kfn.outputNeuronsArray) do neuron
        learn!(neuron, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
    end

    # wrap up the learning session
    if kfn.learningStage == "end_learning"
        kfn.learningStage = "inference"
    end
end
"""
    computeModelError(modelRespond, correctAnswer; magnitude=1.0)

Return `Flux.logitcrossentropy(modelRespond, target) * magnitude`.

`modelRespond` is passed to `logitcrossentropy` as logits. `target` is
`Bool.(correctAnswer)`, so `correctAnswer` must hold only `0`/`1` or `Bool` values
(otherwise `Bool.` throws an `InexactError`). When `correctAnswer === nothing`, an
all-`false` target with the same shape as `modelRespond` is used, so vector and
batched (matrix) responses are both handled.

NOTE(review): with an all-false target the logit cross-entropy evaluates to zero, so
`correctAnswer === nothing` always yields a zero error — confirm this is intended.
"""
function computeModelError(modelRespond, correctAnswer; magnitude::Real=1.0)
    target = if correctAnswer === nothing
        # match the response's shape (not just its length) so matrix inputs work
        falses(size(modelRespond))
    else
        Bool.(correctAnswer) # correct answer for kfn I
    end
    return Flux.logitcrossentropy(modelRespond, target) * magnitude
end
"""
    compute_wRecChange!(m::model, error::Float64)

Accumulate recurrent-weight changes for model `m` from the scalar `error` by
delegating to its `:I` knowledge function.
"""
compute_wRecChange!(m::model, error::Float64) =
    compute_wRecChange!(m.knowledgeFn[:I], error)
"""
    compute_wRecChange!(kfn::kfn_1, error::Float64)

Accumulate recurrent-weight changes for every neuron of `kfn` from the scalar
`error`. Neurons in `kfn.neuronsArray` are processed with `Threads.@threads`;
neurons in `kfn.outputNeuronsArray` are processed afterwards, one at a time.
"""
function compute_wRecChange!(kfn::kfn_1, error::Float64)
    Threads.@threads for neuron in kfn.neuronsArray
        compute_wRecChange!(neuron, error)
    end
    foreach(neuron -> compute_wRecChange!(neuron, error), kfn.outputNeuronsArray)
    return nothing
end
# Passthrough neurons are skipped: no weight change is accumulated for them.
compute_wRecChange!(n::passthroughNeuron, error::Float64) = nothing
"""
    compute_wRecChange!(n::lifNeuron, error::Float64)

Form the eligibility trace `n.eRec = n.phi * n.epsilonRec`, add
`-n.eta * error * n.eRec` into `n.wRecChange` in place, then call
`reset_epsilonRec!(n)` and return its result.
"""
function compute_wRecChange!(n::lifNeuron, error::Float64)
    n.eRec = n.phi * n.epsilonRec
    n.wRecChange .+= -n.eta * error * n.eRec
    return reset_epsilonRec!(n)
end
"""
    compute_wRecChange!(n::alifNeuron, error::Float64)

Build the eligibility trace from two parts:
- `n.eRec_v = n.phi * n.epsilonRec`;
- `n.eRec_a = n.phi * n.beta * n.epsilonRecA`, presumably the adaptation term (confirm).

Store their sum in `n.eRec` and add `-n.eta * error * n.eRec` into `n.wRecChange`
in place. Finally call `reset_epsilonRec!(n)` and `reset_epsilonRecA!(n)`, returning
the latter's result.
"""
function compute_wRecChange!(n::alifNeuron, error::Float64)
    n.eRec_v = n.phi * n.epsilonRec
    n.eRec_a = n.phi * n.beta * n.epsilonRecA
    n.eRec = n.eRec_v + n.eRec_a
    n.wRecChange .+= -n.eta * error * n.eRec
    reset_epsilonRec!(n)
    return reset_epsilonRecA!(n)
end
"""
    compute_wRecChange!(n::linearNeuron, error::Float64)

Form the eligibility trace `n.eRec = n.phi * n.epsilonRec`, add
`-n.eta * error * n.eRec` into `n.wRecChange` in place, then call
`reset_epsilonRec!(n)` and return its result.
"""
function compute_wRecChange!(n::linearNeuron, error::Float64)
    n.eRec = n.phi * n.epsilonRec
    n.wRecChange .+= -n.eta * error * n.eRec
    return reset_epsilonRec!(n)
end
# Input neurons are skipped during learning: their weights are left untouched.
learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron = nothing
"""
    learn!(n::computeNeuron, firedNeurons, nExInType)

Commit the weight change accumulated in compute neuron `n` and run post-update
weight maintenance. Steps, in order:

1. add `n.wRecChange` to `n.wRec`, then call `reset_wRecChange!(n)`;
2. `normalizePeak!` on `n.wRec`, to keep the input signal from overwhelming the neuron;
3. `capMaxWeight!` to cap the maximum weight;
4. `synapticConnStrength!(n)`, then `neuroplasticity!(n, firedNeurons, nExInType)`.

Returns whatever `neuroplasticity!` returns.
"""
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
    # `+=` rebinds n.wRec to a fresh array rather than updating it in place
    n.wRec += n.wRecChange
    reset_wRecChange!(n)
    # NOTE(review): n.wRecChange was reset on the previous line, so normalizePeak!
    # receives the reset accumulator here; confirm this ordering is intended.
    normalizePeak!(n.wRec, n.wRecChange, 2)
    # Sign-flip pruning (zeroing weights whose sign flipped during the merge, to
    # free them for new random connections) is currently disabled, so no sign
    # bookkeeping is performed.
    capMaxWeight!(n.wRec)
    synapticConnStrength!(n)
    return neuroplasticity!(n, firedNeurons, nExInType)
end
"""
    learn!(n::outputNeuron, firedNeurons, nExInType, totalInputPort)

Commit the weight change accumulated in output neuron `n` and run post-update
weight maintenance. Steps, in order:

1. add `n.wRecChange` to `n.wRec`, then call `reset_wRecChange!(n)`;
2. `normalizePeak!` on `n.wRec`, to keep the input signal from overwhelming the neuron;
3. `capMaxWeight!` to cap the maximum weight;
4. `neuroplasticity!(n, firedNeurons, nExInType, totalInputPort)`.

Returns whatever `neuroplasticity!` returns.
"""
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
    # `+=` rebinds n.wRec to a fresh array rather than updating it in place
    n.wRec += n.wRecChange
    reset_wRecChange!(n)
    # NOTE(review): n.wRecChange was reset on the previous line, so normalizePeak!
    # receives the reset accumulator here; confirm this ordering is intended.
    normalizePeak!(n.wRec, n.wRecChange, 2)
    # Sign-flip pruning (zeroing weights whose sign flipped during the merge) is
    # currently disabled, so no sign bookkeeping is performed.
    capMaxWeight!(n.wRec)
    # synapticConnStrength! is currently not applied to output neurons.
    return neuroplasticity!(n, firedNeurons, nExInType, totalInputPort)
end
end # module end