use logitcrossentropy to train
This commit is contained in:
@@ -34,17 +34,20 @@ using .learn
|
||||
|
||||
""" version 0.0.3
|
||||
Todo:
|
||||
[*2] implement connection strength based on right or wrong answer
|
||||
[*1] how to manage how much constrength increase and decrease
|
||||
[4] implement dormant connection
|
||||
[3] Δweight * connection strength
|
||||
[] using RL to control learning signal
|
||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||
[5] training should include adjusting α, neuron membrane potential decay factor
|
||||
which is defined by the neuron.tau_m formula in type.jl
|
||||
|
||||
Change from version: 0.0.2
|
||||
-
|
||||
[DONE] new learning method
|
||||
- use Flux.logitcrossentropy for overall error
|
||||
- remove ΔwRecChange that apply immediately during online learning
|
||||
- collect ΔwRecChange during online learning (0-784th) and merge with wRec at
|
||||
the end of learning (1000th).
|
||||
- compute model error at the end of learning. The model error is multiplied by a constant of 5 for
|
||||
a higher learning impact than the error during online learning
|
||||
|
||||
All features
|
||||
- multidispatch + for loop as main compute method
|
||||
@@ -86,6 +89,9 @@ using .learn
|
||||
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||
this correct neural pathway resists change.
|
||||
Unused connections should disappear (forgetting).
|
||||
|
||||
Removed features
|
||||
- Δweight * connection strength
|
||||
"""
|
||||
|
||||
|
||||
|
||||
94
src/learn.jl
94
src/learn.jl
@@ -1,67 +1,27 @@
|
||||
module learn
|
||||
|
||||
using Statistics, Random, LinearAlgebra, JSON3
|
||||
using Statistics, Random, LinearAlgebra, JSON3, Flux
|
||||
using GeneralUtils
|
||||
using ..types, ..snn_utils
|
||||
|
||||
export learn!
|
||||
export learn!, compute_wRecChange!, computeModelError
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing})
|
||||
if correctAnswer === nothing
|
||||
correctAnswer_I = BitArray(zeros(length(modelRespond)))
|
||||
else
|
||||
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
|
||||
end
|
||||
|
||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||
function learn!(m::model)
|
||||
learn!(m.knowledgeFn[:I])
|
||||
end
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
function learn!(kfn::kfn_1)
|
||||
# compute kfn error for each neuron
|
||||
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
# for (i, out) in enumerate(outs)
|
||||
# if out != correctAnswer[i] # need to adjust weight
|
||||
# kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
# 100 / kfn.outputNeuronsArray[i].v_th )
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
# # for n in kfn.neuronsArray
|
||||
# learn!(n, kfnError)
|
||||
# end
|
||||
|
||||
# learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
# end
|
||||
# end
|
||||
|
||||
# compute kfn error for each neuron
|
||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, out) in enumerate(outs)
|
||||
if out == correctAnswer # output correct
|
||||
kfnError = 0.0
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, kfnError)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, true)
|
||||
end
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort], true)
|
||||
else
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
100.0 / kfn.outputNeuronsArray[i].v_th )^2
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, kfnError)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, false)
|
||||
end
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort], false)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||
end
|
||||
for n in kfn.outputNeuronsArray
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
||||
end
|
||||
|
||||
# wrap up learning session
|
||||
@@ -70,6 +30,30 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
end
|
||||
end
|
||||
|
||||
function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0)
|
||||
if correctAnswer === nothing
|
||||
correctAnswer = BitArray(zeros(length(modelRespond)))
|
||||
else
|
||||
correctAnswer = Bool.(correctAnswer) # correct answer for kfn I
|
||||
end
|
||||
return Flux.logitcrossentropy(modelRespond, correctAnswer) .* magnitude
|
||||
end
|
||||
|
||||
function compute_wRecChange!(m::model, error::Float64)
|
||||
compute_wRecChange!(m.knowledgeFn[:I], error)
|
||||
end
|
||||
|
||||
function compute_wRecChange!(kfn::kfn_1, error::Float64)
|
||||
# compute kfn error for each neuron
|
||||
Threads.@threads for n in kfn.neuronsArray
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, error)
|
||||
end
|
||||
for n in kfn.outputNeuronsArray
|
||||
compute_wRecChange!(n, error)
|
||||
end
|
||||
end
|
||||
|
||||
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
|
||||
# skip
|
||||
end
|
||||
@@ -98,14 +82,12 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
|
||||
reset_epsilonRec!(n)
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron
|
||||
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
|
||||
# skip
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron
|
||||
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||
wSign_0 = sign.(n.wRec) # original sign
|
||||
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
||||
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||
reset_wRecChange!(n)
|
||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||
@@ -116,14 +98,12 @@ function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNe
|
||||
n.wRec .*= nonFlipedSign
|
||||
capMaxWeight!(n.wRec) # cap maximum weight
|
||||
|
||||
synapticConnStrength!(n, correctAnswer)
|
||||
synapticConnStrength!(n)
|
||||
neuroplasticity!(n, firedNeurons, nExInType)
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron
|
||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
|
||||
wSign_0 = sign.(n.wRec) # original sign
|
||||
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
||||
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
n.wRec += n.wRecChange
|
||||
reset_wRecChange!(n)
|
||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||
@@ -134,7 +114,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) wh
|
||||
n.wRec .*= nonFlipedSign
|
||||
capMaxWeight!(n.wRec) # cap maximum weight
|
||||
|
||||
synapticConnStrength!(n, correctAnswer)
|
||||
synapticConnStrength!(n)
|
||||
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||
end
|
||||
|
||||
|
||||
@@ -279,13 +279,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||
|
||||
if updown == "up"
|
||||
if currentStrength > 4 # strong connection
|
||||
updatedStrength = currentStrength + (Δstrength * 0.2)
|
||||
updatedStrength = currentStrength + (Δstrength * 1.0)
|
||||
else
|
||||
updatedStrength = currentStrength + (Δstrength * 0.1)
|
||||
updatedStrength = currentStrength + (Δstrength * 1.0)
|
||||
end
|
||||
elseif updown == "down"
|
||||
if currentStrength > 4
|
||||
updatedStrength = currentStrength - (Δstrength * 0.1)
|
||||
updatedStrength = currentStrength - (Δstrength * 1.0)
|
||||
else
|
||||
updatedStrength = currentStrength - (Δstrength * 1.0)
|
||||
end
|
||||
@@ -294,44 +294,11 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||
end
|
||||
return updatedStrength::Float64
|
||||
end
|
||||
# function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||
# Δstrength = connStrengthAdjust(currentStrength)
|
||||
|
||||
# if updown == "up"
|
||||
# updatedStrength = currentStrength + Δstrength
|
||||
# else
|
||||
# updatedStrength = currentStrength - Δstrength
|
||||
# end
|
||||
# return updatedStrength::Float64
|
||||
# end
|
||||
|
||||
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
||||
below lowerlimit.
|
||||
"""
|
||||
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
# for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
||||
# output neuron fires and trigger learn!() just before this synapticConnStrength
|
||||
# calculation.
|
||||
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||
# """
|
||||
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
||||
# updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||
# n.wRec[i] = 0.0
|
||||
# end
|
||||
# n.synapticStrength[i] = updatedConnStrength
|
||||
# end
|
||||
# end
|
||||
|
||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
|
||||
if correctAnswer == true
|
||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||
@@ -351,18 +318,6 @@ function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAns
|
||||
end
|
||||
n.synapticStrength[i] = updatedConnStrength
|
||||
end
|
||||
else
|
||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
updatedConnStrength = synapticConnStrength(connStrength, "down")
|
||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||
n.wRec[i] = 0.0
|
||||
end
|
||||
n.synapticStrength[i] = updatedConnStrength
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function synapticConnStrength!(n::inputNeuron) end
|
||||
|
||||
Reference in New Issue
Block a user