use logitcrossentropy to train
This commit is contained in:
@@ -34,17 +34,20 @@ using .learn
|
|||||||
|
|
||||||
""" version 0.0.3
|
""" version 0.0.3
|
||||||
Todo:
|
Todo:
|
||||||
[*2] implement connection strength based on right or wrong answer
|
|
||||||
[*1] how to manage how much constrength increase and decrease
|
|
||||||
[4] implement dormant connection
|
[4] implement dormant connection
|
||||||
[3] Δweight * connection strength
|
|
||||||
[] using RL to control learning signal
|
[] using RL to control learning signal
|
||||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||||
[5] training should include adjusting α, neuron membrane potential decay factor
|
[5] training should include adjusting α, neuron membrane potential decay factor
|
||||||
which is defined by the neuron.tau_m formula in type.jl
|
which is defined by the neuron.tau_m formula in type.jl
|
||||||
|
|
||||||
Change from version: 0.0.2
|
Change from version: 0.0.2
|
||||||
-
|
[DONE] new learning method
|
||||||
|
- use Flux.logitcrossentropy for overall error
|
||||||
|
- remove ΔwRecChange that apply immediately during online learning
|
||||||
|
- collect ΔwRecChange during online learning (0-784th) and merge with wRec at
|
||||||
|
the end of learning (1000th).
|
||||||
|
- compute model error at the end of learning. Model error is multiplied by a constant of 5 for
|
||||||
|
higher learning impact than the error during online learning
|
||||||
|
|
||||||
All features
|
All features
|
||||||
- multidispatch + for loop as main compute method
|
- multidispatch + for loop as main compute method
|
||||||
@@ -86,6 +89,9 @@ using .learn
|
|||||||
on the correct answer -> strengthen the right neural pathway (connections) ->
|
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||||
this correct neural pathway resists change.
|
this correct neural pathway resists change.
|
||||||
Unused connections should disappear (forgetting).
|
Unused connections should disappear (forgetting).
|
||||||
|
|
||||||
|
Removed features
|
||||||
|
- Δweight * connection strength
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
98
src/learn.jl
98
src/learn.jl
@@ -1,67 +1,27 @@
|
|||||||
module learn
|
module learn
|
||||||
|
|
||||||
using Statistics, Random, LinearAlgebra, JSON3
|
using Statistics, Random, LinearAlgebra, JSON3, Flux
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types, ..snn_utils
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
export learn!
|
export learn!, compute_wRecChange!, computeModelError
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing})
|
function learn!(m::model)
|
||||||
if correctAnswer === nothing
|
learn!(m.knowledgeFn[:I])
|
||||||
correctAnswer_I = BitArray(zeros(length(modelRespond)))
|
|
||||||
else
|
|
||||||
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
|
|
||||||
end
|
|
||||||
|
|
||||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
""" knowledgeFn learn()
|
""" knowledgeFn learn()
|
||||||
"""
|
"""
|
||||||
function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
function learn!(kfn::kfn_1)
|
||||||
# compute kfn error for each neuron
|
|
||||||
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
|
||||||
# for (i, out) in enumerate(outs)
|
|
||||||
# if out != correctAnswer[i] # need to adjust weight
|
|
||||||
# kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
|
||||||
# 100 / kfn.outputNeuronsArray[i].v_th )
|
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
|
||||||
# # for n in kfn.neuronsArray
|
|
||||||
# learn!(n, kfnError)
|
|
||||||
# end
|
|
||||||
|
|
||||||
# learn!(kfn.outputNeuronsArray[i], kfnError)
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
|
|
||||||
# compute kfn error for each neuron
|
# compute kfn error for each neuron
|
||||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||||
for (i, out) in enumerate(outs)
|
# for n in kfn.neuronsArray
|
||||||
if out == correctAnswer # output correct
|
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||||
kfnError = 0.0
|
|
||||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
|
||||||
# for n in kfn.neuronsArray
|
|
||||||
compute_wRecChange!(n, kfnError)
|
|
||||||
learn!(n, kfn.firedNeurons, kfn.nExInType, true)
|
|
||||||
end
|
|
||||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
|
||||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
|
||||||
kfn.kfnParams[:totalInputPort], true)
|
|
||||||
else
|
|
||||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
|
||||||
100.0 / kfn.outputNeuronsArray[i].v_th )^2
|
|
||||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
|
||||||
# for n in kfn.neuronsArray
|
|
||||||
compute_wRecChange!(n, kfnError)
|
|
||||||
learn!(n, kfn.firedNeurons, kfn.nExInType, false)
|
|
||||||
end
|
|
||||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
|
||||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
|
||||||
kfn.kfnParams[:totalInputPort], false)
|
|
||||||
end
|
end
|
||||||
|
for n in kfn.outputNeuronsArray
|
||||||
|
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
||||||
end
|
end
|
||||||
|
|
||||||
# wrap up learning session
|
# wrap up learning session
|
||||||
@@ -70,6 +30,30 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
|
||||||
|
if correctAnswer === nothing
|
||||||
|
correctAnswer = BitArray(zeros(length(modelRespond)))
|
||||||
|
else
|
||||||
|
correctAnswer = Bool.(correctAnswer) # correct answer for kfn I
|
||||||
|
end
|
||||||
|
return Flux.logitcrossentropy(modelRespond, correctAnswer) .* magnitude
|
||||||
|
end
|
||||||
|
|
||||||
|
function compute_wRecChange!(m::model, error::Float64)
|
||||||
|
compute_wRecChange!(m.knowledgeFn[:I], error)
|
||||||
|
end
|
||||||
|
|
||||||
|
function compute_wRecChange!(kfn::kfn_1, error::Float64)
|
||||||
|
# compute kfn error for each neuron
|
||||||
|
Threads.@threads for n in kfn.neuronsArray
|
||||||
|
# for n in kfn.neuronsArray
|
||||||
|
compute_wRecChange!(n, error)
|
||||||
|
end
|
||||||
|
for n in kfn.outputNeuronsArray
|
||||||
|
compute_wRecChange!(n, error)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
|
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
|
||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
@@ -98,14 +82,12 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
|
|||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron
|
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
|
||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron
|
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
|
||||||
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
|
||||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
@@ -116,14 +98,12 @@ function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNe
|
|||||||
n.wRec .*= nonFlipedSign
|
n.wRec .*= nonFlipedSign
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
|
||||||
synapticConnStrength!(n, correctAnswer)
|
synapticConnStrength!(n)
|
||||||
neuroplasticity!(n, firedNeurons, nExInType)
|
neuroplasticity!(n, firedNeurons, nExInType)
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron
|
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
|
||||||
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
|
||||||
n.wRec += n.wRecChange
|
n.wRec += n.wRecChange
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
@@ -134,7 +114,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) wh
|
|||||||
n.wRec .*= nonFlipedSign
|
n.wRec .*= nonFlipedSign
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
|
||||||
synapticConnStrength!(n, correctAnswer)
|
synapticConnStrength!(n)
|
||||||
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -279,13 +279,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
|
|||||||
|
|
||||||
if updown == "up"
|
if updown == "up"
|
||||||
if currentStrength > 4 # strong connection
|
if currentStrength > 4 # strong connection
|
||||||
updatedStrength = currentStrength + (Δstrength * 0.2)
|
updatedStrength = currentStrength + (Δstrength * 1.0)
|
||||||
else
|
else
|
||||||
updatedStrength = currentStrength + (Δstrength * 0.1)
|
updatedStrength = currentStrength + (Δstrength * 1.0)
|
||||||
end
|
end
|
||||||
elseif updown == "down"
|
elseif updown == "down"
|
||||||
if currentStrength > 4
|
if currentStrength > 4
|
||||||
updatedStrength = currentStrength - (Δstrength * 0.1)
|
updatedStrength = currentStrength - (Δstrength * 1.0)
|
||||||
else
|
else
|
||||||
updatedStrength = currentStrength - (Δstrength * 1.0)
|
updatedStrength = currentStrength - (Δstrength * 1.0)
|
||||||
end
|
end
|
||||||
@@ -294,74 +294,29 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
|
|||||||
end
|
end
|
||||||
return updatedStrength::Float64
|
return updatedStrength::Float64
|
||||||
end
|
end
|
||||||
# function synapticConnStrength(currentStrength::Float64, updown::String)
|
|
||||||
# Δstrength = connStrengthAdjust(currentStrength)
|
|
||||||
|
|
||||||
# if updown == "up"
|
|
||||||
# updatedStrength = currentStrength + Δstrength
|
|
||||||
# else
|
|
||||||
# updatedStrength = currentStrength - Δstrength
|
|
||||||
# end
|
|
||||||
# return updatedStrength::Float64
|
|
||||||
# end
|
|
||||||
|
|
||||||
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
||||||
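below lowerlimit (i.e. when its synaptic strength reaches synapticStrengthLimit.lowerlimit).
|
below lowerlimit (i.e. when its synaptic strength reaches synapticStrengthLimit.lowerlimit).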
|
||||||
"""
|
"""
|
||||||
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||||
# for (i, connStrength) in enumerate(n.synapticStrength)
|
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||||
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
|
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||||
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||||
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
||||||
# output neuron fires and trigger learn!() just before this synapticConnStrength
|
output neuron fires and trigger learn!() just before this synapticConnStrength
|
||||||
# calculation.
|
calculation.
|
||||||
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||||
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||||
# """
|
"""
|
||||||
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
|
||||||
# updatedConnStrength = synapticConnStrength(connStrength, updown)
|
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||||
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||||
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||||
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||||
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||||
# n.wRec[i] = 0.0
|
n.wRec[i] = 0.0
|
||||||
# end
|
|
||||||
# n.synapticStrength[i] = updatedConnStrength
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
|
|
||||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
|
|
||||||
if correctAnswer == true
|
|
||||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
|
||||||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
|
||||||
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
|
||||||
calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
|
||||||
output neuron fires and trigger learn!() just before this synapticConnStrength
|
|
||||||
calculation.
|
|
||||||
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
|
||||||
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
|
||||||
"""
|
|
||||||
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
|
|
||||||
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
|
||||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
|
||||||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
|
||||||
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
|
||||||
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
|
||||||
n.wRec[i] = 0.0
|
|
||||||
end
|
|
||||||
n.synapticStrength[i] = updatedConnStrength
|
|
||||||
end
|
|
||||||
else
|
|
||||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
|
||||||
updatedConnStrength = synapticConnStrength(connStrength, "down")
|
|
||||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
|
||||||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
|
||||||
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
|
||||||
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
|
||||||
n.wRec[i] = 0.0
|
|
||||||
end
|
|
||||||
n.synapticStrength[i] = updatedConnStrength
|
|
||||||
end
|
end
|
||||||
|
n.synapticStrength[i] = updatedConnStrength
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user