Change from end-of-sample learning to online learning
Manifest.toml

@@ -16,9 +16,9 @@ weakdeps = ["ChainRulesCore"]

 [[deps.Accessors]]
 deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
-git-tree-sha1 = "a4f8669e46c8cdf68661fe6bb0f7b89f51dd23cf"
+git-tree-sha1 = "2b301c2388067d655fe5e4ca6d4aa53b61f895b4"
 uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.30"
+version = "0.1.31"

 [deps.Accessors.extensions]
 AccessorsAxisKeysExt = "AxisKeys"
@@ -67,10 +67,24 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 version = "0.4.2"

 [[deps.BangBang]]
-deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
-git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca"
+deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
+git-tree-sha1 = "54b00d1b93791f8e19e31584bd30f2cb6004614b"
 uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
-version = "0.3.37"
+version = "0.3.38"
+
+[deps.BangBang.extensions]
+BangBangChainRulesCoreExt = "ChainRulesCore"
+BangBangDataFramesExt = "DataFrames"
+BangBangStaticArraysExt = "StaticArrays"
+BangBangStructArraysExt = "StructArrays"
+BangBangTypedTablesExt = "TypedTables"
+
+[deps.BangBang.weakdeps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
+TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

 [[deps.Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -155,9 +169,13 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
 version = "1.0.2+0"

 [[deps.CompositionsBase]]
-git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769"
+git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
 uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
-version = "0.1.1"
+version = "0.1.2"
+weakdeps = ["InverseFunctions"]
+
+[deps.CompositionsBase.extensions]
+CompositionsBaseInverseFunctionsExt = "InverseFunctions"

 [[deps.ConstructionBase]]
 deps = ["LinearAlgebra"]
@@ -340,13 +358,13 @@ version = "0.1.4"

 [[deps.GPUCompiler]]
 deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
+git-tree-sha1 = "5737dc242dadd392d934ee330c69ceff47f0259c"
 uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.19.3"
+version = "0.19.4"

 [[deps.GeneralUtils]]
 deps = ["DataStructures", "Distributions", "JSON3"]
-path = "/home/ton/.julia/dev/GeneralUtils"
+path = "C:\\Users\\naraw\\.julia\\dev\\GeneralUtils"
 uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
 version = "0.1.0"

@@ -750,9 +768,9 @@ version = "0.1.15"

 [[deps.StaticArrays]]
 deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
-git-tree-sha1 = "c262c8e978048c2b095be1672c9bee55b4619521"
+git-tree-sha1 = "8982b3607a212b070a5e46eea83eb62b4744ae12"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.5.24"
+version = "1.5.25"

 [[deps.StaticArraysCore]]
 git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"

@@ -25,16 +25,16 @@ using .forward
 include("learn.jl")
 using .learn

-include("readout.jl")
-using .readout
+# include("readout.jl")
+# using .readout

-include("interface.jl")
-using .interface
+# include("interface.jl")
+# using .interface
 #------------------------------------------------------------------------------------------------100

 """
 Todo:
-[*1] add maximum weight cap of each connection

 [2] implement connection strength based on right or wrong answer
 [4] implement dormant connection
 [3] Δweight * connection strength
@@ -63,6 +63,7 @@ using .interface
 [DONE] neuroplasticity() i.e. change connection
 [DONE] add multi threads
 [DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
+[DONE] add maximum weight cap of each connection

 Change from version: v06_36a
 -

src/forward.jl

@@ -1,8 +1,6 @@
 module forward

-using Flux.Optimise: apply!
-
-using Statistics, Flux, Random, LinearAlgebra, JSON3
+using Statistics, Random, LinearAlgebra, JSON3
 using GeneralUtils
 using ..types, ..snn_utils

@@ -77,19 +75,17 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)

     kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] # TODO: check whether this is still used

-    Threads.@threads for n in kfn.neuronsArray
-    # for n in kfn.neuronsArray
+    # Threads.@threads for n in kfn.neuronsArray
+    for n in kfn.neuronsArray
         n(kfn)
     end

     kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
     append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store ids of neurons that fired
     if kfn.learningStage == "end_learning"
         kfn.firedNeurons |> unique! # used for random new neuron connections
     end

-    Threads.@threads for n in kfn.outputNeuronsArray
-    # for n in kfn.outputNeuronsArray
+    # Threads.@threads for n in kfn.outputNeuronsArray
+    for n in kfn.outputNeuronsArray
         n(kfn)
     end

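Note on the threading rollback above: the loops now run serially, and a comment added in src/learn.jl below says "multithreading is not atomic and was causing errors". A minimal, standalone Julia sketch of that failure mode, not code from this repo (acc, acc2, and lk are illustrative names; start Julia with several threads, e.g. julia -t 4, to observe it):

    using Base.Threads

    acc = Int[]
    @threads for i in 1:10_000
        push!(acc, i)        # push! is not thread-safe: concurrent pushes race
    end
    length(acc) == 10_000    # frequently false (lost or corrupted elements)

    # a locked variant, if the parallelism were kept instead of reverted:
    acc2 = Int[]; lk = ReentrantLock()
    @threads for i in 1:10_000
        lock(lk) do
            push!(acc2, i)
        end
    end
    @assert length(acc2) == 10_000
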
src/learn.jl
@@ -1,8 +1,6 @@
 module learn

-using Flux.Optimise: apply!
-
-using Statistics, Flux, Random, LinearAlgebra, JSON3
+using Statistics, Random, LinearAlgebra, JSON3
 using GeneralUtils
 using ..types, ..snn_utils

@@ -12,7 +10,7 @@ export learn!

 function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
     if correctAnswer === nothing
-        correctAnswer_I = BitArray(undef, length(modelRespond))
+        correctAnswer_I = BitArray(zeros(length(modelRespond)))
     else
         correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
     end
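A note on the BitArray change above: BitArray(undef, n) leaves its bits uninitialized, so the default "no correct answer" vector could start with spurious true entries. A minimal sketch with illustrative values:

    n = 8
    a = BitArray(undef, n)   # contents are arbitrary garbage bits
    b = falses(n)            # deterministic all-false BitVector
    c = BitArray(zeros(n))   # same result as falses(n), as the commit uses
    @assert !any(b) && !any(c)
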
@@ -43,76 +41,43 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
     outs = [n.z_t1 for n in kfn.outputNeuronsArray]
     for (i, out) in enumerate(outs)
         if out != correctAnswer[i] # need to adjust weights
-            kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) /
+            kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 /
                          kfn.outputNeuronsArray[i].v_th )
             if correctAnswer[i] == 1 # output neuron associated with the correctAnswer
-                Threads.@threads for n in kfn.neuronsArray
+                Threads.@threads for n in kfn.neuronsArray # multithreading is not atomic and was causing errors
                 # for n in kfn.neuronsArray
-                    learn!(n, kfnError)
+                    compute_wRecChange!(n, kfnError)
+                    learn!(n, kfn.firedNeurons, kfn.nExInType)
                 end

-                learn!(kfn.outputNeuronsArray[i], kfnError)
+                compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
+                learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
+                       kfn.kfnParams[:totalInputPort])
             else # output neuron that is NOT associated with the correctAnswer
-                learn!(kfn.outputNeuronsArray[i], kfnError)
+                compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
+                learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
+                       kfn.kfnParams[:totalInputPort])
             end
         end
     end

-    # wrap up learning session
-    if kfn.learningStage == "end_learning"
-        Threads.@threads for n in kfn.neuronsArray
-        # for n in kfn.neuronsArray
-            if typeof(n) <: computeNeuron
-                wSign_0 = sign.(n.wRec) # original sign
-                n.wRec += n.wRecChange # merge wRecChange into wRec
-                wSign_1 = sign.(n.wRec) # check for flipped signs; 1 indicates a non-flipped sign
-                nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not flipped, 0 flipped
-                # normalize wRec peak to prevent the input signal from overwhelming the neuron
-                normalizePeak!(n.wRec, n.wRecChange, 2)
-                # set weights that flipped sign to 0 for random new connections
-                n.wRec .*= nonFlipedSign
-                capMaxWeight!(n.wRec) # cap maximum weight
-
-                synapticConnStrength!(n)
-                neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
-            end
-        end
-
-        for n in kfn.outputNeuronsArray # merge wRecChange into wRec
-            wSign_0 = sign.(n.wRec) # original sign
-            n.wRec += n.wRecChange
-            wSign_1 = sign.(n.wRec) # check for flipped signs; 1 indicates a non-flipped sign
-            nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not flipped, 0 flipped
-            normalizePeak!(n.wRec, n.wRecChange, 2)
-            n.wRec .*= nonFlipedSign # set weights that flipped sign to 0 for random new connections
-            capMaxWeight!(n.wRec) # cap maximum weight
-
-            synapticConnStrength!(n)
-            neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
-        end
-
-        kfn.learningStage = "inference"
-    end
 end

 """ passthroughNeuron learn()
 """
-function learn!(n::passthroughNeuron, error::Float64)
+function compute_wRecChange!(n::passthroughNeuron, error::Float64)
     # skip
 end

 """ lif learn()
 """
-function learn!(n::lifNeuron, error::Float64)
+function compute_wRecChange!(n::lifNeuron, error::Float64)
     n.eRec = n.phi * n.epsilonRec
     ΔwRecChange = n.eta * error * n.eRec
     n.wRecChange .+= ΔwRecChange
     reset_epsilonRec!(n)
 end

 """ alifNeuron learn()
 """
-function learn!(n::alifNeuron, error::Float64)
+function compute_wRecChange!(n::alifNeuron, error::Float64)
     n.eRec_v = n.phi * n.epsilonRec
     n.eRec_a = -n.phi * n.beta * n.epsilonRecA
     n.eRec = n.eRec_v + n.eRec_a
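The first change in the hunk above rescales the output-neuron error from a fraction of the firing threshold to a percentage of it. A small worked sketch, with illustrative values for v_th and vError:

    v_th, vError = 1.0, 0.7                        # illustrative values
    kfnError_old = (v_th - vError) / v_th          # 0.3, a fraction of threshold
    kfnError_new = (v_th - vError) * 100.0 / v_th  # 30.0, a percentage of threshold
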
@@ -122,14 +87,71 @@ function learn!(n::alifNeuron, error::Float64)
     reset_epsilonRecA!(n)
 end

 """ linearNeuron learn()
 """
-function learn!(n::linearNeuron, error::Float64)
+function compute_wRecChange!(n::linearNeuron, error::Float64)
     n.eRec = n.phi * n.epsilonRec
     ΔwRecChange = n.eta * error * n.eRec
     n.wRecChange .+= ΔwRecChange
     reset_epsilonRec!(n)
 end
+
+function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
+    # skip
+end
+
+function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
+    wSign_0 = sign.(n.wRec) # original sign
+    n.wRec += n.wRecChange # merge wRecChange into wRec
+    reset_wRecChange!(n)
+    wSign_1 = sign.(n.wRec) # check for flipped signs; 1 indicates a non-flipped sign
+    nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not flipped, 0 flipped
+    # normalize wRec peak to prevent the input signal from overwhelming the neuron
+    normalizePeak!(n.wRec, n.wRecChange, 2)
+    # set weights that flipped sign to 0 for random new connections
+    n.wRec .*= nonFlipedSign
+    capMaxWeight!(n.wRec) # cap maximum weight
+
+    synapticConnStrength!(n)
+    neuroplasticity!(n, firedNeurons, nExInType)
+end
+
+function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
+    wSign_0 = sign.(n.wRec) # original sign
+    n.wRec += n.wRecChange
+    reset_wRecChange!(n)
+    wSign_1 = sign.(n.wRec) # check for flipped signs; 1 indicates a non-flipped sign
+    nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not flipped, 0 flipped
+    # normalize wRec peak to prevent the input signal from overwhelming the neuron
+    normalizePeak!(n.wRec, n.wRecChange, 2)
+    # set weights that flipped sign to 0 for random new connections
+    n.wRec .*= nonFlipedSign
+    capMaxWeight!(n.wRec) # cap maximum weight
+
+    synapticConnStrength!(n)
+    neuroplasticity!(n, firedNeurons, nExInType, totalInputPort)
+end
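The merge step in the new learn! methods zeroes any weight whose sign flips when wRecChange is folded in, which frees that synapse for neuroplasticity! to rewire. A standalone sketch of just that rule, with illustrative values:

    w  = [0.5, -0.2, 0.1]
    Δw = [0.1,  0.3, -0.05]
    s0 = sign.(w)                  # signs before the merge
    w += Δw                        # merge the accumulated change
    keep = isequal.(s0, sign.(w))  # 1 where the sign survived, 0 where it flipped
    w .*= keep                     # flipped weights become 0: candidates for rewiring
    @assert w[2] == 0.0            # -0.2 + 0.3 crossed zero, so it is zeroed
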
src/snn_utils.jl

@@ -1,10 +1,9 @@
 module snn_utils

-using Flux.Optimise: apply!
 export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
        precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
        reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
-       reset_epsilonRecA!, synapticConnStrength!, normalizePeak!,
+       reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
        firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
        neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
        gradient_withloss, capMaxWeight!

@@ -253,6 +252,11 @@ function adjust_internal_learning_rate!(n::computeNeuron)
     n.internal_learning_rate * 1.005
 end

+function connStrengthAdjust(currentStrength::Float64)
+    Δstrength = (1.0 - sigmoid(currentStrength))
+    return Δstrength::Float64
+end
+
 """ compute synaptic connection strength. bias shifts currentStrength into the
     sigmoid operating range, which is centred at 0 and spans -37 to 37.
 # Example
@@ -260,21 +264,15 @@ end
     one may use bias = -5 to transform a synaptic strength into the range -5 to 5;
     the return value is shifted back to the original scale
 """
-function synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)::Float64
-    currentStrength += bias
-    if currentStrength > 0
-        Δstrength = (1.0 - sigmoid(currentStrength))
-    else
-        Δstrength = sigmoid(currentStrength)
-    end
+function synapticConnStrength(currentStrength::Float64, updown::String)
+    Δstrength = connStrengthAdjust(currentStrength)

     if updown == "up"
         updatedStrength = currentStrength + Δstrength
     else
         updatedStrength = currentStrength - Δstrength
     end
-    updatedStrength -= bias
-    return updatedStrength
+    return updatedStrength::Float64
 end

 """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
@@ -318,47 +316,6 @@ function normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)
     normalize!(subvector, 1)
 end

-""" rewire neuron synaptic connections that have 0 weight, with the connections'
-    excitatory/inhibitory ratio constraint.
-"""
-# function neuroplasticity!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
-#                           nExcitatory::Vector, nInhabitory::Vector, excitatoryPercent::Integer)
-#     # if there is a 0-weight connection then replace it with a new connection
-#     zeroWeightConnIndex = findall(iszero.(n.wRec)) # connections that have 0 weight
-#     desiredEx = Int(floor((excitatoryPercent / 100) * length(n.subscriptionList)))
-#     desiredIn = length(n.subscriptionList) - desiredEx
-#     wRecSign = sign.(n.wRec)
-#     inConn = sum(isequal.(wRecSign, -1))
-
-#     # random new synaptic connections
-#     inConnToAdd = desiredIn - inConn
-#     if inConnToAdd <= 0
-#         # skip, all new connections will be excitatory type
-#     else
-#         newConnVecSign = ones(length(zeroWeightConnIndex))
-#         newConnVecSign = view(newConnVecSign, 1:inConnToAdd) * -1
-#     end
-
-#     # new synaptic connections must be sampled from neurons that fired
-#     inPool = nInhabitory ∩ firedNeurons
-#     filter!(x -> x ∉ [n.id], inPool) # exclude this neuron's id from the list
-#     filter!(x -> x ∉ n.subscriptionList, inPool) # exclude this neuron's subscriptionList from the list
-#     exPool = nExcitatory ∩ firedNeurons
-#     filter!(x -> x ∉ [n.id], exPool) # exclude this neuron's id from the list
-#     filter!(x -> x ∉ n.subscriptionList, exPool) # exclude this neuron's subscriptionList from the list
-
-#     w = [rand(0.01:0.01:0.2, length(zeroWeightConnIndex))] .* newConnVecSign
-#     synapticStrength = [rand(-5:0.01:-4, length(zeroWeightConnIndex))]
-
-#     # add new synaptic connections to the neuron
-#     for (i, connIndex) in enumerate(zeroWeightConnIndex)
-#         n.subscriptionList[connIndex] = newConnVecSign[i] < 0 ? pop!(inPool) : pop!(exPool)
-#         n.wRec[connIndex] = w[i]
-#         n.synapticStrength[connIndex] = synapticStrength[i]
-#     end
-# end

 """ rewire neuron synaptic connections that have 0 weight, without the connections'
     excitatory/inhibitory ratio constraint.
 """

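For reference, the refactored strength update above factors the sigmoid step out into connStrengthAdjust. A self-contained sketch, assuming sigmoid is the standard logistic (the repo imports its own definition):

    sigmoid(x) = 1 / (1 + exp(-x))                    # assumed definition
    connStrengthAdjust(s::Float64) = 1.0 - sigmoid(s)

    # the step shrinks as a synapse saturates:
    connStrengthAdjust(-4.0)   # ≈ 0.982, a weak synapse moves a lot
    connStrengthAdjust(4.0)    # ≈ 0.018, a strong synapse barely moves

    # applied directionally, as in synapticConnStrength:
    up(s)   = s + connStrengthAdjust(s)
    down(s) = s - connStrengthAdjust(s)
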
src/types.jl
@@ -9,7 +9,7 @@ export
     instantiate_custom_types, init_neuron, populate_neuron,
     add_neuron!

-using Random, Flux, LinearAlgebra
+using Random, LinearAlgebra

 #------------------------------------------------------------------------------------------------100

@@ -350,7 +350,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
     refractoryDuration::Int64 = 3 # neuron's refractory period in milliseconds
     refractoryCounter::Int64 = 0
     tau_m::Float64 = 0.0 # τ_m, membrane time constant in milliseconds
-    eta::Float64 = 0.0001 # η, learning rate
+    eta::Float64 = 0.01 # η, learning rate
     wRecChange::Array{Float64} = Float64[] # Δw_rec, accumulated wRec change
     recSignal::Float64 = 0.0 # incoming recurrent signal
     alpha_v_t::Float64 = 0.0 # alpha * v_t
@@ -438,7 +438,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
     eRec_v::Array{Float64} = Float64[] # component of the neuron's eligibility trace from v_t
     eRec_a::Array{Float64} = Float64[] # component of the neuron's eligibility trace from av_th
     eRec::Array{Float64} = Float64[] # neuron's eligibility trace
-    eta::Float64 = 0.0001 # η, learning rate
+    eta::Float64 = 0.01 # η, learning rate
     gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
     phi::Float64 = 0.0 # ϕ, pseudo-derivative
     refractoryDuration::Int64 = 3 # neuron's refractory period in milliseconds
@@ -448,7 +448,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
     recSignal::Float64 = 0.0 # incoming recurrent signal
     alpha_v_t::Float64 = 0.0 # alpha * v_t
     error::Float64 = 0.0 # local neuron error
-    optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
+    # optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer

     firingCounter::Int64 = 0 # stores how many times the neuron fires
     firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
@@ -548,7 +548,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
     refractoryDuration::Int64 = 3 # neuron's refractory period in milliseconds
     refractoryCounter::Int64 = 0
     tau_out::Float64 = 0.0 # τ_out, membrane time constant in milliseconds
-    eta::Float64 = 0.0001 # η, learning rate
+    eta::Float64 = 0.01 # η, learning rate
     wRecChange::Array{Float64} = Float64[] # Δw_rec, accumulated wRec change
     recSignal::Float64 = 0.0 # incoming recurrent signal
     alpha_v_t::Float64 = 0.0 # alpha * v_t
@@ -588,21 +588,21 @@ end

 #------------------------------------------------------------------------------------------------100

-function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
-    if optimiser_name == "AdaBelief"
-        params = (0.01, (0.9, 0.8))
-        return Flux.Optimise.AdaBelief(params...)
-    elseif optimiser_name == "AdaBelief2"
-        # the output neurons require a slower pace of change, so η is lower than the compute
-        # neurons' at 0.007: if w_out changes too fast, the compute neurons will not be able
-        # to grasp the output neurons' moving direction, i.e. the compute neurons' and the
-        # output neurons' directions go out of sync.
-        params = (0.007, (0.9, 0.8))
-        return Flux.Optimise.AdaBelief(params...)
-    else
-        error("optimiser is not defined yet in load_optimiser()")
-    end
-end
+# function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
+#     if optimiser_name == "AdaBelief"
+#         params = (0.01, (0.9, 0.8))
+#         return Flux.Optimise.AdaBelief(params...)
+#     elseif optimiser_name == "AdaBelief2"
+#         # the output neurons require a slower pace of change, so η is lower than the compute
+#         # neurons' at 0.007: if w_out changes too fast, the compute neurons will not be able
+#         # to grasp the output neurons' moving direction, i.e. the compute neurons' and the
+#         # output neurons' directions go out of sync.
+#         params = (0.007, (0.9, 0.8))
+#         return Flux.Optimise.AdaBelief(params...)
+#     else
+#         error("optimiser is not defined yet in load_optimiser()")
+#     end
+# end

 function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams::Dict)
     n.id = id
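Taken together, the eta bump (0.0001 → 0.01) and the commented-out load_optimiser suggest the per-neuron Flux AdaBelief step has been replaced by the raw eligibility-trace update seen in compute_wRecChange!, so eta alone now sets the step size. A minimal sketch of the new update with illustrative values (err stands in for the kfnError passed in):

    eta, err = 0.01, 0.5        # new default learning rate; illustrative error
    phi = 0.3                   # pseudo-derivative, as in the neuron structs
    epsilonRec = randn(10)      # illustrative eligibility input
    eRec = phi * epsilonRec     # eligibility trace, as in compute_wRecChange!
    wRecChange = zeros(10)
    wRecChange .+= eta * err * eRec  # plain gradient accumulation, no optimiser state
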