change from end of sample learning to online learning
This commit is contained in:
@@ -16,9 +16,9 @@ weakdeps = ["ChainRulesCore"]
|
|||||||
|
|
||||||
[[deps.Accessors]]
|
[[deps.Accessors]]
|
||||||
deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
|
||||||
git-tree-sha1 = "a4f8669e46c8cdf68661fe6bb0f7b89f51dd23cf"
|
git-tree-sha1 = "2b301c2388067d655fe5e4ca6d4aa53b61f895b4"
|
||||||
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
|
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
|
||||||
version = "0.1.30"
|
version = "0.1.31"
|
||||||
|
|
||||||
[deps.Accessors.extensions]
|
[deps.Accessors.extensions]
|
||||||
AccessorsAxisKeysExt = "AxisKeys"
|
AccessorsAxisKeysExt = "AxisKeys"
|
||||||
@@ -67,10 +67,24 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
|
|||||||
version = "0.4.2"
|
version = "0.4.2"
|
||||||
|
|
||||||
[[deps.BangBang]]
|
[[deps.BangBang]]
|
||||||
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
|
deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
|
||||||
git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca"
|
git-tree-sha1 = "54b00d1b93791f8e19e31584bd30f2cb6004614b"
|
||||||
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
|
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
|
||||||
version = "0.3.37"
|
version = "0.3.38"
|
||||||
|
|
||||||
|
[deps.BangBang.extensions]
|
||||||
|
BangBangChainRulesCoreExt = "ChainRulesCore"
|
||||||
|
BangBangDataFramesExt = "DataFrames"
|
||||||
|
BangBangStaticArraysExt = "StaticArrays"
|
||||||
|
BangBangStructArraysExt = "StructArrays"
|
||||||
|
BangBangTypedTablesExt = "TypedTables"
|
||||||
|
|
||||||
|
[deps.BangBang.weakdeps]
|
||||||
|
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
|
||||||
|
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
|
||||||
|
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||||
|
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
|
||||||
|
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
|
||||||
|
|
||||||
[[deps.Base64]]
|
[[deps.Base64]]
|
||||||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||||
@@ -155,9 +169,13 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
|
|||||||
version = "1.0.2+0"
|
version = "1.0.2+0"
|
||||||
|
|
||||||
[[deps.CompositionsBase]]
|
[[deps.CompositionsBase]]
|
||||||
git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769"
|
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
|
||||||
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
|
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
|
weakdeps = ["InverseFunctions"]
|
||||||
|
|
||||||
|
[deps.CompositionsBase.extensions]
|
||||||
|
CompositionsBaseInverseFunctionsExt = "InverseFunctions"
|
||||||
|
|
||||||
[[deps.ConstructionBase]]
|
[[deps.ConstructionBase]]
|
||||||
deps = ["LinearAlgebra"]
|
deps = ["LinearAlgebra"]
|
||||||
@@ -340,13 +358,13 @@ version = "0.1.4"
|
|||||||
|
|
||||||
[[deps.GPUCompiler]]
|
[[deps.GPUCompiler]]
|
||||||
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
|
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
|
||||||
git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
|
git-tree-sha1 = "5737dc242dadd392d934ee330c69ceff47f0259c"
|
||||||
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
|
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
|
||||||
version = "0.19.3"
|
version = "0.19.4"
|
||||||
|
|
||||||
[[deps.GeneralUtils]]
|
[[deps.GeneralUtils]]
|
||||||
deps = ["DataStructures", "Distributions", "JSON3"]
|
deps = ["DataStructures", "Distributions", "JSON3"]
|
||||||
path = "/home/ton/.julia/dev/GeneralUtils"
|
path = "C:\\Users\\naraw\\.julia\\dev\\GeneralUtils"
|
||||||
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
|
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
@@ -750,9 +768,9 @@ version = "0.1.15"
|
|||||||
|
|
||||||
[[deps.StaticArrays]]
|
[[deps.StaticArrays]]
|
||||||
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
|
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
|
||||||
git-tree-sha1 = "c262c8e978048c2b095be1672c9bee55b4619521"
|
git-tree-sha1 = "8982b3607a212b070a5e46eea83eb62b4744ae12"
|
||||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||||
version = "1.5.24"
|
version = "1.5.25"
|
||||||
|
|
||||||
[[deps.StaticArraysCore]]
|
[[deps.StaticArraysCore]]
|
||||||
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
|
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
|
||||||
|
|||||||
@@ -25,16 +25,16 @@ using .forward
|
|||||||
include("learn.jl")
|
include("learn.jl")
|
||||||
using .learn
|
using .learn
|
||||||
|
|
||||||
include("readout.jl")
|
# include("readout.jl")
|
||||||
using .readout
|
# using .readout
|
||||||
|
|
||||||
include("interface.jl")
|
# include("interface.jl")
|
||||||
using .interface
|
# using .interface
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Todo:
|
Todo:
|
||||||
[*1] add maximum weight cap of each connection
|
|
||||||
[2] implement connection strength based on right or wrong answer
|
[2] implement connection strength based on right or wrong answer
|
||||||
[4] implement dormant connection
|
[4] implement dormant connection
|
||||||
[3] Δweight * connection strength
|
[3] Δweight * connection strength
|
||||||
@@ -63,6 +63,7 @@ using .interface
|
|||||||
[DONE] neuroplasticity() i.e. change connection
|
[DONE] neuroplasticity() i.e. change connection
|
||||||
[DONE] add multi threads
|
[DONE] add multi threads
|
||||||
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
|
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
|
||||||
|
[DONE] add maximum weight cap of each connection
|
||||||
|
|
||||||
Change from version: v06_36a
|
Change from version: v06_36a
|
||||||
-
|
-
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
module forward
|
module forward
|
||||||
|
|
||||||
using Flux.Optimise: apply!
|
using Statistics, Random, LinearAlgebra, JSON3
|
||||||
|
|
||||||
using Statistics, Flux, Random, LinearAlgebra, JSON3
|
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types, ..snn_utils
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
@@ -77,19 +75,17 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
|
|
||||||
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
|
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
|
||||||
|
|
||||||
Threads.@threads for n in kfn.neuronsArray
|
# Threads.@threads for n in kfn.neuronsArray
|
||||||
# for n in kfn.neuronsArray
|
for n in kfn.neuronsArray
|
||||||
n(kfn)
|
n(kfn)
|
||||||
end
|
end
|
||||||
|
|
||||||
kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
|
kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
|
||||||
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
|
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
|
||||||
if kfn.learningStage == "end_learning"
|
|
||||||
kfn.firedNeurons |> unique! # use for random new neuron connection
|
kfn.firedNeurons |> unique! # use for random new neuron connection
|
||||||
end
|
|
||||||
|
|
||||||
Threads.@threads for n in kfn.outputNeuronsArray
|
# Threads.@threads for n in kfn.outputNeuronsArray
|
||||||
# for n in kfn.outputNeuronsArray
|
for n in kfn.outputNeuronsArray
|
||||||
n(kfn)
|
n(kfn)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
128
src/learn.jl
128
src/learn.jl
@@ -1,8 +1,6 @@
|
|||||||
module learn
|
module learn
|
||||||
|
|
||||||
using Flux.Optimise: apply!
|
using Statistics, Random, LinearAlgebra, JSON3
|
||||||
|
|
||||||
using Statistics, Flux, Random, LinearAlgebra, JSON3
|
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types, ..snn_utils
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
@@ -12,7 +10,7 @@ export learn!
|
|||||||
|
|
||||||
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
|
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
|
||||||
if correctAnswer === nothing
|
if correctAnswer === nothing
|
||||||
correctAnswer_I = BitArray(undef, length(modelRespond))
|
correctAnswer_I = BitArray(zeros(length(modelRespond)))
|
||||||
else
|
else
|
||||||
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
|
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
|
||||||
end
|
end
|
||||||
@@ -43,76 +41,43 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
|||||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||||
for (i, out) in enumerate(outs)
|
for (i, out) in enumerate(outs)
|
||||||
if out != correctAnswer[i] # need to adjust weight
|
if out != correctAnswer[i] # need to adjust weight
|
||||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) /
|
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 /
|
||||||
kfn.outputNeuronsArray[i].v_th )
|
kfn.outputNeuronsArray[i].v_th )
|
||||||
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
||||||
Threads.@threads for n in kfn.neuronsArray
|
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||||
# for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
learn!(n, kfnError)
|
compute_wRecChange!(n, kfnError)
|
||||||
|
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||||
end
|
end
|
||||||
|
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||||
learn!(kfn.outputNeuronsArray[i], kfnError)
|
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||||
|
kfn.kfnParams[:totalInputPort])
|
||||||
else # output neuron that is NOT associated with correctAnswer
|
else # output neuron that is NOT associated with correctAnswer
|
||||||
learn!(kfn.outputNeuronsArray[i], kfnError)
|
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||||
|
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||||
|
kfn.kfnParams[:totalInputPort])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# wrap up learning session
|
# wrap up learning session
|
||||||
if kfn.learningStage == "end_learning"
|
if kfn.learningStage == "end_learning"
|
||||||
Threads.@threads for n in kfn.neuronsArray
|
|
||||||
# for n in kfn.neuronsArray
|
|
||||||
if typeof(n) <: computeNeuron
|
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
|
||||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
|
||||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
|
||||||
# normalize wRec peak to prevent input signal overwhelming neuron
|
|
||||||
normalizePeak!(n.wRec, n.wRecChange, 2)
|
|
||||||
# set weight that fliped sign to 0 for random new connection
|
|
||||||
n.wRec .*= nonFlipedSign
|
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
|
||||||
|
|
||||||
synapticConnStrength!(n)
|
|
||||||
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
for n in kfn.outputNeuronsArray # merge wRecChange into wRec
|
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
|
||||||
n.wRec += n.wRecChange
|
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
|
||||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
|
||||||
normalizePeak!(n.wRec, n.wRecChange, 2)
|
|
||||||
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
|
||||||
|
|
||||||
synapticConnStrength!(n)
|
|
||||||
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
|
||||||
end
|
|
||||||
|
|
||||||
kfn.learningStage = "inference"
|
kfn.learningStage = "inference"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
""" passthroughNeuron learn()
|
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
|
||||||
"""
|
|
||||||
function learn!(n::passthroughNeuron, error::Float64)
|
|
||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
|
|
||||||
""" lif learn()
|
function compute_wRecChange!(n::lifNeuron, error::Float64)
|
||||||
"""
|
|
||||||
function learn!(n::lifNeuron, error::Float64)
|
|
||||||
n.eRec = n.phi * n.epsilonRec
|
n.eRec = n.phi * n.epsilonRec
|
||||||
ΔwRecChange = n.eta * error * n.eRec
|
ΔwRecChange = n.eta * error * n.eRec
|
||||||
n.wRecChange .+= ΔwRecChange
|
n.wRecChange .+= ΔwRecChange
|
||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
""" alifNeuron learn()
|
function compute_wRecChange!(n::alifNeuron, error::Float64)
|
||||||
"""
|
|
||||||
function learn!(n::alifNeuron, error::Float64)
|
|
||||||
n.eRec_v = n.phi * n.epsilonRec
|
n.eRec_v = n.phi * n.epsilonRec
|
||||||
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
|
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
|
||||||
n.eRec = n.eRec_v + n.eRec_a
|
n.eRec = n.eRec_v + n.eRec_a
|
||||||
@@ -122,14 +87,71 @@ function learn!(n::alifNeuron, error::Float64)
|
|||||||
reset_epsilonRecA!(n)
|
reset_epsilonRecA!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
""" linearNeuron learn()
|
function compute_wRecChange!(n::linearNeuron, error::Float64)
|
||||||
"""
|
|
||||||
function learn!(n::linearNeuron, error::Float64)
|
|
||||||
n.eRec = n.phi * n.epsilonRec
|
n.eRec = n.phi * n.epsilonRec
|
||||||
ΔwRecChange = n.eta * error * n.eRec
|
ΔwRecChange = n.eta * error * n.eRec
|
||||||
n.wRecChange .+= ΔwRecChange
|
n.wRecChange .+= ΔwRecChange
|
||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
|
||||||
|
# skip
|
||||||
|
end
|
||||||
|
|
||||||
|
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||||
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
|
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||||
|
reset_wRecChange!(n)
|
||||||
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
|
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||||
|
# normalize wRec peak to prevent input signal overwhelming neuron
|
||||||
|
normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||||
|
# set weight that fliped sign to 0 for random new connection
|
||||||
|
n.wRec .*= nonFlipedSign
|
||||||
|
capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
|
||||||
|
synapticConnStrength!(n)
|
||||||
|
neuroplasticity!(n, firedNeurons, nExInType)
|
||||||
|
end
|
||||||
|
|
||||||
|
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
|
||||||
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
|
n.wRec += n.wRecChange
|
||||||
|
reset_wRecChange!(n)
|
||||||
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
|
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||||
|
# normalize wRec peak to prevent input signal overwhelming neuron
|
||||||
|
normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||||
|
# set weight that fliped sign to 0 for random new connection
|
||||||
|
n.wRec .*= nonFlipedSign
|
||||||
|
capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
|
||||||
|
synapticConnStrength!(n)
|
||||||
|
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
module snn_utils
|
module snn_utils
|
||||||
|
|
||||||
using Flux.Optimise: apply!
|
|
||||||
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
|
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
|
||||||
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
||||||
reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
|
reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
|
||||||
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!,
|
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
|
||||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||||
gradient_withloss, capMaxWeight!
|
gradient_withloss, capMaxWeight!
|
||||||
@@ -253,6 +252,11 @@ function adjust_internal_learning_rate!(n::computeNeuron)
|
|||||||
n.internal_learning_rate * 1.005
|
n.internal_learning_rate * 1.005
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function connStrengthAdjust(currentStrength::Float64)
|
||||||
|
Δstrength = (1.0 - sigmoid(currentStrength))
|
||||||
|
return Δstrength::Float64
|
||||||
|
end
|
||||||
|
|
||||||
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
||||||
sigmoid operating range which centred at 0 and range is -37 to 37.
|
sigmoid operating range which centred at 0 and range is -37 to 37.
|
||||||
# Example
|
# Example
|
||||||
@@ -260,21 +264,15 @@ end
|
|||||||
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
||||||
the return value is shifted back to original scale
|
the return value is shifted back to original scale
|
||||||
"""
|
"""
|
||||||
function synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)::Float64
|
function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||||
currentStrength += bias
|
Δstrength = connStrengthAdjust(currentStrength)
|
||||||
if currentStrength > 0
|
|
||||||
Δstrength = (1.0 - sigmoid(currentStrength))
|
|
||||||
else
|
|
||||||
Δstrength = sigmoid(currentStrength)
|
|
||||||
end
|
|
||||||
|
|
||||||
if updown == "up"
|
if updown == "up"
|
||||||
updatedStrength = currentStrength + Δstrength
|
updatedStrength = currentStrength + Δstrength
|
||||||
else
|
else
|
||||||
updatedStrength = currentStrength - Δstrength
|
updatedStrength = currentStrength - Δstrength
|
||||||
end
|
end
|
||||||
updatedStrength -= bias
|
return updatedStrength::Float64
|
||||||
return updatedStrength
|
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
||||||
@@ -318,47 +316,6 @@ function normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)
|
|||||||
normalize!(subvector, 1)
|
normalize!(subvector, 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
""" rewire of neuron synaptic connection that has 0 weight. With connection's excitatory and
|
|
||||||
inhabitory ratio constraint.
|
|
||||||
"""
|
|
||||||
# function neuroplasticity!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
|
|
||||||
# nExcitatory::Vector, nInhabitory::Vector, excitatoryPercent::Integer)
|
|
||||||
# # if there is 0-weight then replace it with new connection
|
|
||||||
# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
|
||||||
# desiredEx = Int(floor((excitatoryPercent / 100) * length(n.subscriptionList)))
|
|
||||||
# desiredIn = length(n.subscriptionList) - desiredEx
|
|
||||||
# wRecSign = sign.(n.wRec)
|
|
||||||
# inConn = sum(isequal.(wRecSign, -1))
|
|
||||||
|
|
||||||
# # random new synaptic connection
|
|
||||||
# inConnToAdd = desiredIn - inConn
|
|
||||||
# if inConnToAdd <= 0
|
|
||||||
# # skip all new Conn will be excitatory type
|
|
||||||
# else
|
|
||||||
# newConnVecSign = ones(length(zeroWeightConnIndex))
|
|
||||||
# newConnVecSign = view(newConnVecSign, 1:inConnToAdd) * -1
|
|
||||||
# end
|
|
||||||
|
|
||||||
# # new synaptic connection must sample fron neuron that fires
|
|
||||||
# inPool = nInhabitory ∩ firedNeurons
|
|
||||||
# filter!(x -> x ∉ [n.id], inPool) # exclude this neuron id from the list
|
|
||||||
# filter!(x -> x ∉ n.subscriptionList, inPool) # exclude this neuron's subscriptionList from the list
|
|
||||||
# exPool = nExcitatory ∩ firedNeurons
|
|
||||||
# filter!(x -> x ∉ [n.id], exPool) # exclude this neuron id from the list
|
|
||||||
# filter!(x -> x ∉ n.subscriptionList, exPool) # exclude this neuron's subscriptionList from the list
|
|
||||||
|
|
||||||
# w = [rand(0.01:0.01:0.2, length(zeroWeightConnIndex))] .* newConnVecSign
|
|
||||||
# synapticStrength = [rand(-5:0.01:-4, length(zeroWeightConnIndex))]
|
|
||||||
|
|
||||||
# # add new synaptic connection to neuron
|
|
||||||
# for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
|
||||||
# n.subscriptionList[connIndex] = newConnVecSign[i] < 0 ? pop!(inPool) : pop!(exPool)
|
|
||||||
# n.wRec[connIndex] = w[i]
|
|
||||||
# n.synapticStrength[connIndex] = synapticStrength[i]
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
|
||||||
""" rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory and
|
""" rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory and
|
||||||
inhabitory ratio constraint.
|
inhabitory ratio constraint.
|
||||||
"""
|
"""
|
||||||
|
|||||||
40
src/types.jl
40
src/types.jl
@@ -9,7 +9,7 @@ export
|
|||||||
instantiate_custom_types, init_neuron, populate_neuron,
|
instantiate_custom_types, init_neuron, populate_neuron,
|
||||||
add_neuron!
|
add_neuron!
|
||||||
|
|
||||||
using Random, Flux, LinearAlgebra
|
using Random, LinearAlgebra
|
||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
|
|||||||
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||||
refractoryCounter::Int64 = 0
|
refractoryCounter::Int64 = 0
|
||||||
tau_m::Float64 = 0.0 # τ_m, membrane time constant in millisecond
|
tau_m::Float64 = 0.0 # τ_m, membrane time constant in millisecond
|
||||||
eta::Float64 = 0.0001 # η, learning rate
|
eta::Float64 = 0.01 # η, learning rate
|
||||||
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||||
@@ -438,7 +438,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
|||||||
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
|
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
|
||||||
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
|
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
|
||||||
eRec::Array{Float64} = Float64[] # neuron's eligibility trace
|
eRec::Array{Float64} = Float64[] # neuron's eligibility trace
|
||||||
eta::Float64 = 0.0001 # eta, learning rate
|
eta::Float64 = 0.01 # eta, learning rate
|
||||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||||
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
|
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
|
||||||
@@ -448,7 +448,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
|||||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||||
error::Float64 = 0.0 # local neuron error
|
error::Float64 = 0.0 # local neuron error
|
||||||
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
|
# optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
|
||||||
|
|
||||||
firingCounter::Int64 = 0 # store how many times neuron fires
|
firingCounter::Int64 = 0 # store how many times neuron fires
|
||||||
firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
|
firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
|
||||||
@@ -548,7 +548,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
|||||||
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||||
refractoryCounter::Int64 = 0
|
refractoryCounter::Int64 = 0
|
||||||
tau_out::Float64 = 0.0 # τ_out, membrane time constant in millisecond
|
tau_out::Float64 = 0.0 # τ_out, membrane time constant in millisecond
|
||||||
eta::Float64 = 0.0001 # η, learning rate
|
eta::Float64 = 0.01 # η, learning rate
|
||||||
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||||
@@ -588,21 +588,21 @@ end
|
|||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
|
# function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
|
||||||
if optimiser_name == "AdaBelief"
|
# if optimiser_name == "AdaBelief"
|
||||||
params = (0.01, (0.9, 0.8))
|
# params = (0.01, (0.9, 0.8))
|
||||||
return Flux.Optimise.AdaBelief(params...)
|
# return Flux.Optimise.AdaBelief(params...)
|
||||||
elseif optimiser_name == "AdaBelief2"
|
# elseif optimiser_name == "AdaBelief2"
|
||||||
# output neuron requires slower change pace so η is lower than compute neuron at 0.007
|
# # output neuron requires slower change pace so η is lower than compute neuron at 0.007
|
||||||
# because if w_out change too fast, compute neuron will not able to
|
# # because if w_out change too fast, compute neuron will not able to
|
||||||
# grapse output neuron moving direction i.e. both compute neuron's direction and
|
# # grapse output neuron moving direction i.e. both compute neuron's direction and
|
||||||
# output neuron direction are out of sync.
|
# # output neuron direction are out of sync.
|
||||||
params = (0.007, (0.9, 0.8))
|
# params = (0.007, (0.9, 0.8))
|
||||||
return Flux.Optimise.AdaBelief(params...)
|
# return Flux.Optimise.AdaBelief(params...)
|
||||||
else
|
# else
|
||||||
error("optimiser is not defined yet in load_optimiser()")
|
# error("optimiser is not defined yet in load_optimiser()")
|
||||||
end
|
# end
|
||||||
end
|
# end
|
||||||
|
|
||||||
function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams::Dict)
|
function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams::Dict)
|
||||||
n.id = id
|
n.id = id
|
||||||
|
|||||||
Reference in New Issue
Block a user