use logitcrossentropy to train

This commit is contained in:
2023-05-28 22:07:22 +07:00
parent fe35066a94
commit 8fd0cc0fdd
3 changed files with 70 additions and 129 deletions

View File

@@ -34,17 +34,20 @@ using .learn
""" version 0.0.3 """ version 0.0.3
Todo: Todo:
[*2] implement connection strength based on right or wrong answer
[*1] how to manage how much constrength increase and decrease
[4] implement dormant connection [4] implement dormant connection
[3] Δweight * connection strength
[] using RL to control learning signal [] using RL to control learning signal
[] consider using Dates.now() instead of timestamp because time_stamp may overflow [] consider using Dates.now() instead of timestamp because time_stamp may overflow
[5] training should include adjusting α, neuron membrane potential decay factor [5] training should include adjusting α, neuron membrane potential decay factor
which defined by neuron.tau_m formula in type.jl which defined by neuron.tau_m formula in type.jl
Change from version: 0.0.2 Change from version: 0.0.2
- [DONE] new learning method
- use Flux.logitcrossentropy for overall error
- remove ΔwRecChange that apply immediately during online learning
- collect ΔwRecChange during online learning (0-784th) and merge with wRec at
the end of learning (1000th).
- compute model error at the end of learning. The model error is multiplied by a
constant of 5 for higher learning impact than the error during online learning
All features All features
- multidispatch + for loop as main compute method - multidispatch + for loop as main compute method
@@ -86,6 +89,9 @@ using .learn
on the correct answer -> strengthen the right neural pathway (connections) -> on the correct answer -> strengthen the right neural pathway (connections) ->
this correct neural pathway resists change. this correct neural pathway resists change.
Unused connections should disappear (forgetting). Unused connections should disappear (forgetting).
Removed features
- Δweight * connection strength
""" """

View File

@@ -1,67 +1,27 @@
module learn module learn
using Statistics, Random, LinearAlgebra, JSON3 using Statistics, Random, LinearAlgebra, JSON3, Flux
using GeneralUtils using GeneralUtils
using ..types, ..snn_utils using ..types, ..snn_utils
export learn! export learn!, compute_wRecChange!, computeModelError
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing}) function learn!(m::model)
if correctAnswer === nothing learn!(m.knowledgeFn[:I])
correctAnswer_I = BitArray(zeros(length(modelRespond)))
else
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
end
learn!(m.knowledgeFn[:I], correctAnswer_I)
end end
""" knowledgeFn learn() """ knowledgeFn learn()
""" """
function learn!(kfn::kfn_1, correctAnswer::BitVector) function learn!(kfn::kfn_1)
# compute kfn error for each neuron # compute kfn error for each neuron
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
# for (i, out) in enumerate(outs)
# if out != correctAnswer[i] # need to adjust weight
# kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
# 100 / kfn.outputNeuronsArray[i].v_th )
# Threads.@threads for n in kfn.neuronsArray
# # for n in kfn.neuronsArray
# learn!(n, kfnError)
# end
# learn!(kfn.outputNeuronsArray[i], kfnError)
# end
# end
# compute kfn error for each neuron
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs)
if out == correctAnswer # output correct
kfnError = 0.0
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray # for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError) learn!(n, kfn.firedNeurons, kfn.nExInType)
learn!(n, kfn.firedNeurons, kfn.nExInType, true)
end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort], true)
else
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
100.0 / kfn.outputNeuronsArray[i].v_th )^2
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType, false)
end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort], false)
end end
for n in kfn.outputNeuronsArray
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
end end
# wrap up learning session # wrap up learning session
@@ -70,6 +30,30 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
end end
end end
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
if correctAnswer === nothing
correctAnswer = BitArray(zeros(length(modelRespond)))
else
correctAnswer = Bool.(correctAnswer) # correct answer for kfn I
end
return Flux.logitcrossentropy(modelRespond, correctAnswer) .* magnitude
end
function compute_wRecChange!(m::model, error::Float64)
compute_wRecChange!(m.knowledgeFn[:I], error)
end
function compute_wRecChange!(kfn::kfn_1, error::Float64)
# compute kfn error for each neuron
Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray
compute_wRecChange!(n, error)
end
for n in kfn.outputNeuronsArray
compute_wRecChange!(n, error)
end
end
function compute_wRecChange!(n::passthroughNeuron, error::Float64) function compute_wRecChange!(n::passthroughNeuron, error::Float64)
# skip # skip
end end
@@ -98,14 +82,12 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
reset_epsilonRec!(n) reset_epsilonRec!(n)
end end
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
# skip # skip
end end
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
wSign_0 = sign.(n.wRec) # original sign wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange # merge wRecChange into wRec n.wRec += n.wRecChange # merge wRecChange into wRec
reset_wRecChange!(n) reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -116,14 +98,12 @@ function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNe
n.wRec .*= nonFlipedSign n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n, correctAnswer) synapticConnStrength!(n)
neuroplasticity!(n, firedNeurons, nExInType) neuroplasticity!(n, firedNeurons, nExInType)
end end
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
wSign_0 = sign.(n.wRec) # original sign wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange n.wRec += n.wRecChange
reset_wRecChange!(n) reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -134,7 +114,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) wh
n.wRec .*= nonFlipedSign n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n, correctAnswer) synapticConnStrength!(n)
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
end end

View File

@@ -279,13 +279,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
if updown == "up" if updown == "up"
if currentStrength > 4 # strong connection if currentStrength > 4 # strong connection
updatedStrength = currentStrength + (Δstrength * 0.2) updatedStrength = currentStrength + (Δstrength * 1.0)
else else
updatedStrength = currentStrength + (Δstrength * 0.1) updatedStrength = currentStrength + (Δstrength * 1.0)
end end
elseif updown == "down" elseif updown == "down"
if currentStrength > 4 if currentStrength > 4
updatedStrength = currentStrength - (Δstrength * 0.1) updatedStrength = currentStrength - (Δstrength * 1.0)
else else
updatedStrength = currentStrength - (Δstrength * 1.0) updatedStrength = currentStrength - (Δstrength * 1.0)
end end
@@ -294,44 +294,11 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
end end
return updatedStrength::Float64 return updatedStrength::Float64
end end
# function synapticConnStrength(currentStrength::Float64, updown::String)
# Δstrength = connStrengthAdjust(currentStrength)
# if updown == "up"
# updatedStrength = currentStrength + Δstrength
# else
# updatedStrength = currentStrength - Δstrength
# end
# return updatedStrength::Float64
# end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
below lowerlimit. below lowerlimit.
""" """
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
# for (i, connStrength) in enumerate(n.synapticStrength)
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
# output neuron fires and trigger learn!() just before this synapticConnStrength
# calculation.
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
# """
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
# updatedConnStrength = synapticConnStrength(connStrength, updown)
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
# n.wRec[i] = 0.0
# end
# n.synapticStrength[i] = updatedConnStrength
# end
# end
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
if correctAnswer == true
for (i, connStrength) in enumerate(n.synapticStrength) for (i, connStrength) in enumerate(n.synapticStrength)
# check whether connStrength increase or decrease based on usage from n.epsilonRec # check whether connStrength increase or decrease based on usage from n.epsilonRec
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
@@ -351,18 +318,6 @@ function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAns
end end
n.synapticStrength[i] = updatedConnStrength n.synapticStrength[i] = updatedConnStrength
end end
else
for (i, connStrength) in enumerate(n.synapticStrength)
updatedConnStrength = synapticConnStrength(connStrength, "down")
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end
end
end end
function synapticConnStrength!(n::inputNeuron) end function synapticConnStrength!(n::inputNeuron) end