change from end of sample learning to online learning

This commit is contained in:
2023-05-26 21:03:03 +07:00
parent 3556167591
commit b0cede75c1
6 changed files with 146 additions and 152 deletions

View File

@@ -16,9 +16,9 @@ weakdeps = ["ChainRulesCore"]
[[deps.Accessors]] [[deps.Accessors]]
deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "a4f8669e46c8cdf68661fe6bb0f7b89f51dd23cf" git-tree-sha1 = "2b301c2388067d655fe5e4ca6d4aa53b61f895b4"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
version = "0.1.30" version = "0.1.31"
[deps.Accessors.extensions] [deps.Accessors.extensions]
AccessorsAxisKeysExt = "AxisKeys" AccessorsAxisKeysExt = "AxisKeys"
@@ -67,10 +67,24 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.4.2" version = "0.4.2"
[[deps.BangBang]] [[deps.BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" git-tree-sha1 = "54b00d1b93791f8e19e31584bd30f2cb6004614b"
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
version = "0.3.37" version = "0.3.38"
[deps.BangBang.extensions]
BangBangChainRulesCoreExt = "ChainRulesCore"
BangBangDataFramesExt = "DataFrames"
BangBangStaticArraysExt = "StaticArrays"
BangBangStructArraysExt = "StructArrays"
BangBangTypedTablesExt = "TypedTables"
[deps.BangBang.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
[[deps.Base64]] [[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -155,9 +169,13 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.2+0" version = "1.0.2+0"
[[deps.CompositionsBase]] [[deps.CompositionsBase]]
git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
version = "0.1.1" version = "0.1.2"
weakdeps = ["InverseFunctions"]
[deps.CompositionsBase.extensions]
CompositionsBaseInverseFunctionsExt = "InverseFunctions"
[[deps.ConstructionBase]] [[deps.ConstructionBase]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
@@ -340,13 +358,13 @@ version = "0.1.4"
[[deps.GPUCompiler]] [[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7" git-tree-sha1 = "5737dc242dadd392d934ee330c69ceff47f0259c"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.19.3" version = "0.19.4"
[[deps.GeneralUtils]] [[deps.GeneralUtils]]
deps = ["DataStructures", "Distributions", "JSON3"] deps = ["DataStructures", "Distributions", "JSON3"]
path = "/home/ton/.julia/dev/GeneralUtils" path = "C:\\Users\\naraw\\.julia\\dev\\GeneralUtils"
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
version = "0.1.0" version = "0.1.0"
@@ -750,9 +768,9 @@ version = "0.1.15"
[[deps.StaticArrays]] [[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "c262c8e978048c2b095be1672c9bee55b4619521" git-tree-sha1 = "8982b3607a212b070a5e46eea83eb62b4744ae12"
uuid = "90137ffa-7385-5640-81b9-e52037218182" uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.24" version = "1.5.25"
[[deps.StaticArraysCore]] [[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"

View File

@@ -25,16 +25,16 @@ using .forward
include("learn.jl") include("learn.jl")
using .learn using .learn
include("readout.jl") # include("readout.jl")
using .readout # using .readout
include("interface.jl") # include("interface.jl")
using .interface # using .interface
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
""" """
Todo: Todo:
[*1] add maximum weight cap of each connection
[2] implement connection strength based on right or wrong answer [2] implement connection strength based on right or wrong answer
[4] implement dormant connection [4] implement dormant connection
[3] Δweight * connection strength [3] Δweight * connection strength
@@ -63,6 +63,7 @@ using .interface
[DONE] neuroplasticity() i.e. change connection [DONE] neuroplasticity() i.e. change connection
[DONE] add multi threads [DONE] add multi threads
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons [DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
[DONE] add maximum weight cap of each connection
Change from version: v06_36a Change from version: v06_36a
- -

View File

@@ -1,8 +1,6 @@
module forward module forward
using Flux.Optimise: apply! using Statistics, Random, LinearAlgebra, JSON3
using Statistics, Flux, Random, LinearAlgebra, JSON3
using GeneralUtils using GeneralUtils
using ..types, ..snn_utils using ..types, ..snn_utils
@@ -77,19 +75,17 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used? kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
Threads.@threads for n in kfn.neuronsArray # Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray for n in kfn.neuronsArray
n(kfn) n(kfn)
end end
kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray] kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
if kfn.learningStage == "end_learning" kfn.firedNeurons |> unique! # use for random new neuron connection
kfn.firedNeurons |> unique! # use for random new neuron connection
end
Threads.@threads for n in kfn.outputNeuronsArray # Threads.@threads for n in kfn.outputNeuronsArray
# for n in kfn.outputNeuronsArray for n in kfn.outputNeuronsArray
n(kfn) n(kfn)
end end

View File

@@ -1,8 +1,6 @@
module learn module learn
using Flux.Optimise: apply! using Statistics, Random, LinearAlgebra, JSON3
using Statistics, Flux, Random, LinearAlgebra, JSON3
using GeneralUtils using GeneralUtils
using ..types, ..snn_utils using ..types, ..snn_utils
@@ -12,7 +10,7 @@ export learn!
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing}) function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
if correctAnswer === nothing if correctAnswer === nothing
correctAnswer_I = BitArray(undef, length(modelRespond)) correctAnswer_I = BitArray(zeros(length(modelRespond)))
else else
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
end end
@@ -43,76 +41,43 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
outs = [n.z_t1 for n in kfn.outputNeuronsArray] outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs) for (i, out) in enumerate(outs)
if out != correctAnswer[i] # need to adjust weight if out != correctAnswer[i] # need to adjust weight
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) / kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 /
kfn.outputNeuronsArray[i].v_th ) kfn.outputNeuronsArray[i].v_th )
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
Threads.@threads for n in kfn.neuronsArray Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray # for n in kfn.neuronsArray
learn!(n, kfnError) compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType)
end end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfnError) learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort])
else # output neuron that is NOT associated with correctAnswer else # output neuron that is NOT associated with correctAnswer
learn!(kfn.outputNeuronsArray[i], kfnError) compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort])
end end
end end
end end
# wrap up learning session # wrap up learning session
if kfn.learningStage == "end_learning" if kfn.learningStage == "end_learning"
Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray
if typeof(n) <: computeNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange # merge wRecChange into wRec
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
end
end
for n in kfn.outputNeuronsArray # merge wRecChange into wRec
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
normalizePeak!(n.wRec, n.wRecChange, 2)
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
end
kfn.learningStage = "inference" kfn.learningStage = "inference"
end end
end end
""" passthroughNeuron learn() function compute_wRecChange!(n::passthroughNeuron, error::Float64)
"""
function learn!(n::passthroughNeuron, error::Float64)
# skip # skip
end end
""" lif learn() function compute_wRecChange!(n::lifNeuron, error::Float64)
"""
function learn!(n::lifNeuron, error::Float64)
n.eRec = n.phi * n.epsilonRec n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error * n.eRec ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n) reset_epsilonRec!(n)
end end
""" alifNeuron learn() function compute_wRecChange!(n::alifNeuron, error::Float64)
"""
function learn!(n::alifNeuron, error::Float64)
n.eRec_v = n.phi * n.epsilonRec n.eRec_v = n.phi * n.epsilonRec
n.eRec_a = -n.phi * n.beta * n.epsilonRecA n.eRec_a = -n.phi * n.beta * n.epsilonRecA
n.eRec = n.eRec_v + n.eRec_a n.eRec = n.eRec_v + n.eRec_a
@@ -122,14 +87,71 @@ function learn!(n::alifNeuron, error::Float64)
reset_epsilonRecA!(n) reset_epsilonRecA!(n)
end end
""" linearNeuron learn() function compute_wRecChange!(n::linearNeuron, error::Float64)
"""
function learn!(n::linearNeuron, error::Float64)
n.eRec = n.phi * n.epsilonRec n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error * n.eRec ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n) reset_epsilonRec!(n)
end end
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
# skip
end
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange # merge wRecChange into wRec
reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n, firedNeurons, nExInType)
end
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange
reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
end

View File

@@ -1,10 +1,9 @@
module snn_utils module snn_utils
using Flux.Optimise: apply!
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!, export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!, precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!, reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!, firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!, neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss, capMaxWeight! gradient_withloss, capMaxWeight!
@@ -253,6 +252,11 @@ function adjust_internal_learning_rate!(n::computeNeuron)
n.internal_learning_rate * 1.005 n.internal_learning_rate * 1.005
end end
function connStrengthAdjust(currentStrength::Float64)
Δstrength = (1.0 - sigmoid(currentStrength))
return Δstrength::Float64
end
""" compute synaptic connection strength. bias will shift currentStrength to fit into """ compute synaptic connection strength. bias will shift currentStrength to fit into
sigmoid operating range which centred at 0 and range is -37 to 37. sigmoid operating range which centred at 0 and range is -37 to 37.
# Example # Example
@@ -260,21 +264,15 @@ end
one may use bias = -5 to transform synaptic strength into range -5 to 5 one may use bias = -5 to transform synaptic strength into range -5 to 5
the return value is shifted back to original scale the return value is shifted back to original scale
""" """
function synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)::Float64 function synapticConnStrength(currentStrength::Float64, updown::String)
currentStrength += bias Δstrength = connStrengthAdjust(currentStrength)
if currentStrength > 0
Δstrength = (1.0 - sigmoid(currentStrength))
else
Δstrength = sigmoid(currentStrength)
end
if updown == "up" if updown == "up"
updatedStrength = currentStrength + Δstrength updatedStrength = currentStrength + Δstrength
else else
updatedStrength = currentStrength - Δstrength updatedStrength = currentStrength - Δstrength
end end
updatedStrength -= bias return updatedStrength::Float64
return updatedStrength
end end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
@@ -318,47 +316,6 @@ function normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)
normalize!(subvector, 1) normalize!(subvector, 1)
end end
""" rewire of neuron synaptic connection that has 0 weight. With connection's excitatory and
inhabitory ratio constraint.
"""
# function neuroplasticity!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
# nExcitatory::Vector, nInhabitory::Vector, excitatoryPercent::Integer)
# # if there is 0-weight then replace it with new connection
# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
# desiredEx = Int(floor((excitatoryPercent / 100) * length(n.subscriptionList)))
# desiredIn = length(n.subscriptionList) - desiredEx
# wRecSign = sign.(n.wRec)
# inConn = sum(isequal.(wRecSign, -1))
# # random new synaptic connection
# inConnToAdd = desiredIn - inConn
# if inConnToAdd <= 0
# # skip all new Conn will be excitatory type
# else
# newConnVecSign = ones(length(zeroWeightConnIndex))
# newConnVecSign = view(newConnVecSign, 1:inConnToAdd) * -1
# end
# # new synaptic connection must sample fron neuron that fires
# inPool = nInhabitory ∩ firedNeurons
# filter!(x -> x ∉ [n.id], inPool) # exclude this neuron id from the list
# filter!(x -> x ∉ n.subscriptionList, inPool) # exclude this neuron's subscriptionList from the list
# exPool = nExcitatory ∩ firedNeurons
# filter!(x -> x ∉ [n.id], exPool) # exclude this neuron id from the list
# filter!(x -> x ∉ n.subscriptionList, exPool) # exclude this neuron's subscriptionList from the list
# w = [rand(0.01:0.01:0.2, length(zeroWeightConnIndex))] .* newConnVecSign
# synapticStrength = [rand(-5:0.01:-4, length(zeroWeightConnIndex))]
# # add new synaptic connection to neuron
# for (i, connIndex) in enumerate(zeroWeightConnIndex)
# n.subscriptionList[connIndex] = newConnVecSign[i] < 0 ? pop!(inPool) : pop!(exPool)
# n.wRec[connIndex] = w[i]
# n.synapticStrength[connIndex] = synapticStrength[i]
# end
# end
""" rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory and """ rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory and
inhabitory ratio constraint. inhabitory ratio constraint.
""" """

View File

@@ -9,7 +9,7 @@ export
instantiate_custom_types, init_neuron, populate_neuron, instantiate_custom_types, init_neuron, populate_neuron,
add_neuron! add_neuron!
using Random, Flux, LinearAlgebra using Random, LinearAlgebra
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
@@ -350,7 +350,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
refractoryCounter::Int64 = 0 refractoryCounter::Int64 = 0
tau_m::Float64 = 0.0 # τ_m, membrane time constant in millisecond tau_m::Float64 = 0.0 # τ_m, membrane time constant in millisecond
eta::Float64 = 0.0001 # η, learning rate eta::Float64 = 0.01 # η, learning rate
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
recSignal::Float64 = 0.0 # incoming recurrent signal recSignal::Float64 = 0.0 # incoming recurrent signal
alpha_v_t::Float64 = 0.0 # alpha * v_t alpha_v_t::Float64 = 0.0 # alpha * v_t
@@ -438,7 +438,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
eRec::Array{Float64} = Float64[] # neuron's eligibility trace eRec::Array{Float64} = Float64[] # neuron's eligibility trace
eta::Float64 = 0.0001 # eta, learning rate eta::Float64 = 0.01 # eta, learning rate
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
phi::Float64 = 0.0 # ϕ, psuedo derivative phi::Float64 = 0.0 # ϕ, psuedo derivative
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
@@ -448,7 +448,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
recSignal::Float64 = 0.0 # incoming recurrent signal recSignal::Float64 = 0.0 # incoming recurrent signal
alpha_v_t::Float64 = 0.0 # alpha * v_t alpha_v_t::Float64 = 0.0 # alpha * v_t
error::Float64 = 0.0 # local neuron error error::Float64 = 0.0 # local neuron error
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer # optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
firingCounter::Int64 = 0 # store how many times neuron fires firingCounter::Int64 = 0 # store how many times neuron fires
firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
@@ -548,7 +548,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
refractoryCounter::Int64 = 0 refractoryCounter::Int64 = 0
tau_out::Float64 = 0.0 # τ_out, membrane time constant in millisecond tau_out::Float64 = 0.0 # τ_out, membrane time constant in millisecond
eta::Float64 = 0.0001 # η, learning rate eta::Float64 = 0.01 # η, learning rate
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
recSignal::Float64 = 0.0 # incoming recurrent signal recSignal::Float64 = 0.0 # incoming recurrent signal
alpha_v_t::Float64 = 0.0 # alpha * v_t alpha_v_t::Float64 = 0.0 # alpha * v_t
@@ -588,21 +588,21 @@ end
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing) # function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
if optimiser_name == "AdaBelief" # if optimiser_name == "AdaBelief"
params = (0.01, (0.9, 0.8)) # params = (0.01, (0.9, 0.8))
return Flux.Optimise.AdaBelief(params...) # return Flux.Optimise.AdaBelief(params...)
elseif optimiser_name == "AdaBelief2" # elseif optimiser_name == "AdaBelief2"
# output neuron requires slower change pace so η is lower than compute neuron at 0.007 # # output neuron requires slower change pace so η is lower than compute neuron at 0.007
# because if w_out change too fast, compute neuron will not able to # # because if w_out change too fast, compute neuron will not able to
# grapse output neuron moving direction i.e. both compute neuron's direction and # # grapse output neuron moving direction i.e. both compute neuron's direction and
# output neuron direction are out of sync. # # output neuron direction are out of sync.
params = (0.007, (0.9, 0.8)) # params = (0.007, (0.9, 0.8))
return Flux.Optimise.AdaBelief(params...) # return Flux.Optimise.AdaBelief(params...)
else # else
error("optimiser is not defined yet in load_optimiser()") # error("optimiser is not defined yet in load_optimiser()")
end # end
end # end
function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams::Dict) function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams::Dict)
n.id = id n.id = id