use logitcrossentropy to train

This commit is contained in:
2023-05-28 22:07:22 +07:00
parent fe35066a94
commit 8fd0cc0fdd
3 changed files with 70 additions and 129 deletions

View File

@@ -34,17 +34,20 @@ using .learn
""" version 0.0.3
Todo:
[*2] implement connection strength based on right or wrong answer
[*1] how to manage how much constrength increase and decrease
[4] implement dormant connection
[3] Δweight * connection strength
[] using RL to control learning signal
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
[5] training should include adjusting α, neuron membrane potential decay factor
    which is defined by the neuron.tau_m formula in type.jl
Change from version: 0.0.2
-
[DONE] new learning method
- use Flux.logitcrossentropy for overall error
- remove ΔwRecChange that apply immediately during online learning
- collect ΔwRecChange during online learning (0-784th) and merge with wRec at
the end learning (1000th).
- compute model error at the end of learning. The model error is multiplied by a
  constant of 5 for higher learning impact than the error during online learning
All features
- multidispatch + for loop as main compute method
@@ -86,6 +89,9 @@ using .learn
on the correct answer -> strengthen the right neural pathway (connections) ->
this correct neural pathway resist to change.
Unused connections should disappear (forgetting).
Removed features
- Δweight * connection strength
"""

View File

@@ -1,67 +1,27 @@
module learn
using Statistics, Random, LinearAlgebra, JSON3
using Statistics, Random, LinearAlgebra, JSON3, Flux
using GeneralUtils
using ..types, ..snn_utils
export learn!
export learn!, compute_wRecChange!, computeModelError
#------------------------------------------------------------------------------------------------100
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing})
if correctAnswer === nothing
correctAnswer_I = BitArray(zeros(length(modelRespond)))
else
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
end
learn!(m.knowledgeFn[:I], correctAnswer_I)
function learn!(m::model)
learn!(m.knowledgeFn[:I])
end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer::BitVector)
# compute kfn error for each neuron
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
# for (i, out) in enumerate(outs)
# if out != correctAnswer[i] # need to adjust weight
# kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
# 100 / kfn.outputNeuronsArray[i].v_th )
# Threads.@threads for n in kfn.neuronsArray
# # for n in kfn.neuronsArray
# learn!(n, kfnError)
# end
# learn!(kfn.outputNeuronsArray[i], kfnError)
# end
# end
function learn!(kfn::kfn_1)
# compute kfn error for each neuron
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs)
if out == correctAnswer # output correct
kfnError = 0.0
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType, true)
end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort], true)
else
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
100.0 / kfn.outputNeuronsArray[i].v_th )^2
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType, false)
end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort], false)
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
learn!(n, kfn.firedNeurons, kfn.nExInType)
end
for n in kfn.outputNeuronsArray
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
end
# wrap up learning session
@@ -70,6 +30,30 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
end
end
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
if correctAnswer === nothing
correctAnswer = BitArray(zeros(length(modelRespond)))
else
correctAnswer = Bool.(correctAnswer) # correct answer for kfn I
end
return Flux.logitcrossentropy(modelRespond, correctAnswer) .* magnitude
end
function compute_wRecChange!(m::model, error::Float64)
compute_wRecChange!(m.knowledgeFn[:I], error)
end
function compute_wRecChange!(kfn::kfn_1, error::Float64)
# compute kfn error for each neuron
Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray
compute_wRecChange!(n, error)
end
for n in kfn.outputNeuronsArray
compute_wRecChange!(n, error)
end
end
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
# skip
end
@@ -98,14 +82,12 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
reset_epsilonRec!(n)
end
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
# skip
end
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange # merge wRecChange into wRec
reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -116,14 +98,12 @@ function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNe
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n, correctAnswer)
synapticConnStrength!(n)
neuroplasticity!(n, firedNeurons, nExInType)
end
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange
reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -134,7 +114,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) wh
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n, correctAnswer)
synapticConnStrength!(n)
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
end

View File

@@ -279,13 +279,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
if updown == "up"
if currentStrength > 4 # strong connection
updatedStrength = currentStrength + (Δstrength * 0.2)
updatedStrength = currentStrength + (Δstrength * 1.0)
else
updatedStrength = currentStrength + (Δstrength * 0.1)
updatedStrength = currentStrength + (Δstrength * 1.0)
end
elseif updown == "down"
if currentStrength > 4
updatedStrength = currentStrength - (Δstrength * 0.1)
updatedStrength = currentStrength - (Δstrength * 1.0)
else
updatedStrength = currentStrength - (Δstrength * 1.0)
end
@@ -294,74 +294,29 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
end
return updatedStrength::Float64
end
# function synapticConnStrength(currentStrength::Float64, updown::String)
# Δstrength = connStrengthAdjust(currentStrength)
# if updown == "up"
# updatedStrength = currentStrength + Δstrength
# else
# updatedStrength = currentStrength - Δstrength
# end
# return updatedStrength::Float64
# end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
below lowerlimit.
"""
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
# for (i, connStrength) in enumerate(n.synapticStrength)
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
# output neuron fires and trigger learn!() just before this synapticConnStrength
# calculation.
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
# """
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
# updatedConnStrength = synapticConnStrength(connStrength, updown)
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
# n.wRec[i] = 0.0
# end
# n.synapticStrength[i] = updatedConnStrength
# end
# end
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
if correctAnswer == true
for (i, connStrength) in enumerate(n.synapticStrength)
# check whether connStrength increase or decrease based on usage from n.epsilonRec
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
calculation in learn!() will reset epsilonRec to zeroes vector in case where
output neuron fires and trigger learn!() just before this synapticConnStrength
calculation.
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
"""
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
updatedConnStrength = synapticConnStrength(connStrength, updown)
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end
else
for (i, connStrength) in enumerate(n.synapticStrength)
updatedConnStrength = synapticConnStrength(connStrength, "down")
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
for (i, connStrength) in enumerate(n.synapticStrength)
# check whether connStrength increase or decrease based on usage from n.epsilonRec
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
calculation in learn!() will reset epsilonRec to zeroes vector in case where
output neuron fires and trigger learn!() just before this synapticConnStrength
calculation.
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
"""
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
updatedConnStrength = synapticConnStrength(connStrength, updown)
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end
end