use logitcrossentropy to train

This commit is contained in:
2023-05-28 22:07:22 +07:00
parent fe35066a94
commit 8fd0cc0fdd
3 changed files with 70 additions and 129 deletions

View File

@@ -34,17 +34,20 @@ using .learn
""" version 0.0.3 """ version 0.0.3
Todo: Todo:
[*2] implement connection strength based on right or wrong answer
[*1] how to manage how much constrength increase and decrease
[4] implement dormant connection [4] implement dormant connection
[3] Δweight * connection strength
[] using RL to control learning signal [] using RL to control learning signal
[] consider using Dates.now() instead of timestamp because time_stamp may overflow [] consider using Dates.now() instead of timestamp because time_stamp may overflow
[5] training should include adjusting α, neuron membrane potential decay factor [5] training should include adjusting α, neuron membrane potential decay factor
which defined by neuron.tau_m formula in type.jl which defined by neuron.tau_m formula in type.jl
Change from version: 0.0.2 Change from version: 0.0.2
- [DONE] new learning method
- use Flux.logitcrossentropy for overall error
- remove ΔwRecChange that applies immediately during online learning
- collect ΔwRecChange during online learning (0-784th) and merge it with wRec at
the end of learning (1000th).
- compute model error at the end of learning. The model error is multiplied by a constant
of 5 for higher learning impact than the error during online learning
All features All features
- multidispatch + for loop as main compute method - multidispatch + for loop as main compute method
@@ -86,6 +89,9 @@ using .learn
on the correct answer -> strengthen the right neural pathway (connections) -> on the correct answer -> strengthen the right neural pathway (connections) ->
this correct neural pathway resist to change. this correct neural pathway resist to change.
Not used connection should disappear (forgetting). Not used connection should disappear (forgetting).
Removed features
- Δweight * connection strength
""" """

View File

@@ -1,67 +1,27 @@
module learn module learn
using Statistics, Random, LinearAlgebra, JSON3 using Statistics, Random, LinearAlgebra, JSON3, Flux
using GeneralUtils using GeneralUtils
using ..types, ..snn_utils using ..types, ..snn_utils
export learn! export learn!, compute_wRecChange!, computeModelError
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing}) function learn!(m::model)
if correctAnswer === nothing learn!(m.knowledgeFn[:I])
correctAnswer_I = BitArray(zeros(length(modelRespond)))
else
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
end
learn!(m.knowledgeFn[:I], correctAnswer_I)
end end
""" knowledgeFn learn() """ knowledgeFn learn()
""" """
function learn!(kfn::kfn_1, correctAnswer::BitVector) function learn!(kfn::kfn_1)
# compute kfn error for each neuron
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
# for (i, out) in enumerate(outs)
# if out != correctAnswer[i] # need to adjust weight
# kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
# 100 / kfn.outputNeuronsArray[i].v_th )
# Threads.@threads for n in kfn.neuronsArray
# # for n in kfn.neuronsArray
# learn!(n, kfnError)
# end
# learn!(kfn.outputNeuronsArray[i], kfnError)
# end
# end
# compute kfn error for each neuron # compute kfn error for each neuron
outs = [n.z_t1 for n in kfn.outputNeuronsArray] Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
for (i, out) in enumerate(outs) # for n in kfn.neuronsArray
if out == correctAnswer # output correct learn!(n, kfn.firedNeurons, kfn.nExInType)
kfnError = 0.0
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType, true)
end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort], true)
else
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
100.0 / kfn.outputNeuronsArray[i].v_th )^2
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType, false)
end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort], false)
end end
for n in kfn.outputNeuronsArray
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
end end
# wrap up learning session # wrap up learning session
@@ -70,6 +30,30 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
end end
end end
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
if correctAnswer === nothing
correctAnswer = BitArray(zeros(length(modelRespond)))
else
correctAnswer = Bool.(correctAnswer) # correct answer for kfn I
end
return Flux.logitcrossentropy(modelRespond, correctAnswer) .* magnitude
end
function compute_wRecChange!(m::model, error::Float64)
compute_wRecChange!(m.knowledgeFn[:I], error)
end
function compute_wRecChange!(kfn::kfn_1, error::Float64)
# compute kfn error for each neuron
Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray
compute_wRecChange!(n, error)
end
for n in kfn.outputNeuronsArray
compute_wRecChange!(n, error)
end
end
function compute_wRecChange!(n::passthroughNeuron, error::Float64) function compute_wRecChange!(n::passthroughNeuron, error::Float64)
# skip # skip
end end
@@ -98,14 +82,12 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
reset_epsilonRec!(n) reset_epsilonRec!(n)
end end
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
# skip # skip
end end
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
wSign_0 = sign.(n.wRec) # original sign wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange # merge wRecChange into wRec n.wRec += n.wRecChange # merge wRecChange into wRec
reset_wRecChange!(n) reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -116,14 +98,12 @@ function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNe
n.wRec .*= nonFlipedSign n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n, correctAnswer) synapticConnStrength!(n)
neuroplasticity!(n, firedNeurons, nExInType) neuroplasticity!(n, firedNeurons, nExInType)
end end
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
wSign_0 = sign.(n.wRec) # original sign wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange n.wRec += n.wRecChange
reset_wRecChange!(n) reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -134,7 +114,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) wh
n.wRec .*= nonFlipedSign n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n, correctAnswer) synapticConnStrength!(n)
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
end end

View File

@@ -279,13 +279,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
if updown == "up" if updown == "up"
if currentStrength > 4 # strong connection if currentStrength > 4 # strong connection
updatedStrength = currentStrength + (Δstrength * 0.2) updatedStrength = currentStrength + (Δstrength * 1.0)
else else
updatedStrength = currentStrength + (Δstrength * 0.1) updatedStrength = currentStrength + (Δstrength * 1.0)
end end
elseif updown == "down" elseif updown == "down"
if currentStrength > 4 if currentStrength > 4
updatedStrength = currentStrength - (Δstrength * 0.1) updatedStrength = currentStrength - (Δstrength * 1.0)
else else
updatedStrength = currentStrength - (Δstrength * 1.0) updatedStrength = currentStrength - (Δstrength * 1.0)
end end
@@ -294,74 +294,29 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
end end
return updatedStrength::Float64 return updatedStrength::Float64
end end
# function synapticConnStrength(currentStrength::Float64, updown::String)
# Δstrength = connStrengthAdjust(currentStrength)
# if updown == "up"
# updatedStrength = currentStrength + Δstrength
# else
# updatedStrength = currentStrength - Δstrength
# end
# return updatedStrength::Float64
# end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
below lowerlimit. below lowerlimit.
""" """
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
# for (i, connStrength) in enumerate(n.synapticStrength) for (i, connStrength) in enumerate(n.synapticStrength)
# # check whether connStrength increase or decrease based on usage from n.epsilonRec # check whether connStrength increase or decrease based on usage from n.epsilonRec
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
# calculation in learn!() will reset epsilonRec to zeroes vector in case where calculation in learn!() will reset epsilonRec to zeroes vector in case where
# output neuron fires and trigger learn!() just before this synapticConnStrength output neuron fires and trigger learn!() just before this synapticConnStrength
# calculation. calculation.
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
# ok to use. n.z_i_t_commulative also span across a training sample without resetting. ok to use. n.z_i_t_commulative also span across a training sample without resetting.
# """ """
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" # updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
# updatedConnStrength = synapticConnStrength(connStrength, updown) updatedConnStrength = synapticConnStrength(connStrength, updown)
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
# n.wRec[i] = 0.0 n.wRec[i] = 0.0
# end
# n.synapticStrength[i] = updatedConnStrength
# end
# end
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
if correctAnswer == true
for (i, connStrength) in enumerate(n.synapticStrength)
# check whether connStrength increase or decrease based on usage from n.epsilonRec
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
calculation in learn!() will reset epsilonRec to zeroes vector in case where
output neuron fires and trigger learn!() just before this synapticConnStrength
calculation.
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
"""
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
updatedConnStrength = synapticConnStrength(connStrength, updown)
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end
else
for (i, connStrength) in enumerate(n.synapticStrength)
updatedConnStrength = synapticConnStrength(connStrength, "down")
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end end
n.synapticStrength[i] = updatedConnStrength
end end
end end