change from end of sample learning to online learning

This commit is contained in:
2023-05-26 21:03:03 +07:00
parent 3556167591
commit b0cede75c1
6 changed files with 146 additions and 152 deletions

View File

@@ -16,9 +16,9 @@ weakdeps = ["ChainRulesCore"]
[[deps.Accessors]]
deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "a4f8669e46c8cdf68661fe6bb0f7b89f51dd23cf"
git-tree-sha1 = "2b301c2388067d655fe5e4ca6d4aa53b61f895b4"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
version = "0.1.30"
version = "0.1.31"
[deps.Accessors.extensions]
AccessorsAxisKeysExt = "AxisKeys"
@@ -67,10 +67,24 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.4.2"
[[deps.BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca"
deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
git-tree-sha1 = "54b00d1b93791f8e19e31584bd30f2cb6004614b"
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
version = "0.3.37"
version = "0.3.38"
[deps.BangBang.extensions]
BangBangChainRulesCoreExt = "ChainRulesCore"
BangBangDataFramesExt = "DataFrames"
BangBangStaticArraysExt = "StaticArrays"
BangBangStructArraysExt = "StructArrays"
BangBangTypedTablesExt = "TypedTables"
[deps.BangBang.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -155,9 +169,13 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.2+0"
[[deps.CompositionsBase]]
git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769"
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
version = "0.1.1"
version = "0.1.2"
weakdeps = ["InverseFunctions"]
[deps.CompositionsBase.extensions]
CompositionsBaseInverseFunctionsExt = "InverseFunctions"
[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
@@ -340,13 +358,13 @@ version = "0.1.4"
[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
git-tree-sha1 = "5737dc242dadd392d934ee330c69ceff47f0259c"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.19.3"
version = "0.19.4"
[[deps.GeneralUtils]]
deps = ["DataStructures", "Distributions", "JSON3"]
path = "/home/ton/.julia/dev/GeneralUtils"
path = "C:\\Users\\naraw\\.julia\\dev\\GeneralUtils"
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
version = "0.1.0"
@@ -750,9 +768,9 @@ version = "0.1.15"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "c262c8e978048c2b095be1672c9bee55b4619521"
git-tree-sha1 = "8982b3607a212b070a5e46eea83eb62b4744ae12"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.24"
version = "1.5.25"
[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"

View File

@@ -25,16 +25,16 @@ using .forward
include("learn.jl")
using .learn
include("readout.jl")
using .readout
# include("readout.jl")
# using .readout
include("interface.jl")
using .interface
# include("interface.jl")
# using .interface
#------------------------------------------------------------------------------------------------100
"""
Todo:
[*1] add maximum weight cap of each connection
[2] implement connection strength based on right or wrong answer
[4] implement dormant connection
[3] Δweight * connection strength
@@ -63,6 +63,7 @@ using .interface
[DONE] neuroplasticity() i.e. change connection
[DONE] add multi threads
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
[DONE] add maximum weight cap of each connection
Change from version: v06_36a
-

View File

@@ -1,8 +1,6 @@
module forward
using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra, JSON3
using Statistics, Random, LinearAlgebra, JSON3
using GeneralUtils
using ..types, ..snn_utils
@@ -77,19 +75,17 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray
# Threads.@threads for n in kfn.neuronsArray
for n in kfn.neuronsArray
n(kfn)
end
kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
if kfn.learningStage == "end_learning"
kfn.firedNeurons |> unique! # use for random new neuron connection
end
kfn.firedNeurons |> unique! # use for random new neuron connection
Threads.@threads for n in kfn.outputNeuronsArray
# for n in kfn.outputNeuronsArray
# Threads.@threads for n in kfn.outputNeuronsArray
for n in kfn.outputNeuronsArray
n(kfn)
end

View File

@@ -1,8 +1,6 @@
module learn
using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra, JSON3
using Statistics, Random, LinearAlgebra, JSON3
using GeneralUtils
using ..types, ..snn_utils
@@ -12,7 +10,7 @@ export learn!
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
if correctAnswer === nothing
correctAnswer_I = BitArray(undef, length(modelRespond))
correctAnswer_I = BitArray(zeros(length(modelRespond)))
else
correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I
end
@@ -43,76 +41,43 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs)
if out != correctAnswer[i] # need to adjust weight
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) /
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 /
kfn.outputNeuronsArray[i].v_th )
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
Threads.@threads for n in kfn.neuronsArray
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray
learn!(n, kfnError)
compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType)
end
learn!(kfn.outputNeuronsArray[i], kfnError)
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort])
else # output neuron that is NOT associated with correctAnswer
learn!(kfn.outputNeuronsArray[i], kfnError)
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort])
end
end
end
# wrap up learning session
if kfn.learningStage == "end_learning"
Threads.@threads for n in kfn.neuronsArray
# for n in kfn.neuronsArray
if typeof(n) <: computeNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange # merge wRecChange into wRec
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
end
end
for n in kfn.outputNeuronsArray # merge wRecChange into wRec
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
normalizePeak!(n.wRec, n.wRecChange, 2)
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
end
kfn.learningStage = "inference"
end
end
""" passthroughNeuron learn()
"""
function learn!(n::passthroughNeuron, error::Float64)
function compute_wRecChange!(n::passthroughNeuron, error::Float64)
# skip
end
""" lif learn()
"""
function learn!(n::lifNeuron, error::Float64)
function compute_wRecChange!(n::lifNeuron, error::Float64)
n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n)
end
""" alifNeuron learn()
"""
function learn!(n::alifNeuron, error::Float64)
function compute_wRecChange!(n::alifNeuron, error::Float64)
n.eRec_v = n.phi * n.epsilonRec
n.eRec_a = -n.phi * n.beta * n.epsilonRecA
n.eRec = n.eRec_v + n.eRec_a
@@ -122,14 +87,71 @@ function learn!(n::alifNeuron, error::Float64)
reset_epsilonRecA!(n)
end
""" linearNeuron learn()
"""
function learn!(n::linearNeuron, error::Float64)
function compute_wRecChange!(n::linearNeuron, error::Float64)
n.eRec = n.phi * n.epsilonRec
ΔwRecChange = n.eta * error * n.eRec
n.wRecChange .+= ΔwRecChange
reset_epsilonRec!(n)
end
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
# skip
end
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange # merge wRecChange into wRec
reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n, firedNeurons, nExInType)
end
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
wSign_0 = sign.(n.wRec) # original sign
n.wRec += n.wRecChange
reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
# normalize wRec peak to prevent input signal overwhelming neuron
normalizePeak!(n.wRec, n.wRecChange, 2)
# set weight that fliped sign to 0 for random new connection
n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n)
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
end

View File

@@ -1,10 +1,9 @@
module snn_utils
using Flux.Optimise: apply!
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!,
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss, capMaxWeight!
@@ -253,6 +252,11 @@ function adjust_internal_learning_rate!(n::computeNeuron)
n.internal_learning_rate * 1.005
end
function connStrengthAdjust(currentStrength::Float64)
Δstrength = (1.0 - sigmoid(currentStrength))
return Δstrength::Float64
end
""" compute synaptic connection strength. bias will shift currentStrength to fit into
sigmoid operating range which centred at 0 and range is -37 to 37.
# Example
@@ -260,21 +264,15 @@ end
one may use bias = -5 to transform synaptic strength into range -5 to 5
the return value is shifted back to original scale
"""
function synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)::Float64
currentStrength += bias
if currentStrength > 0
Δstrength = (1.0 - sigmoid(currentStrength))
else
Δstrength = sigmoid(currentStrength)
end
function synapticConnStrength(currentStrength::Float64, updown::String)
Δstrength = connStrengthAdjust(currentStrength)
if updown == "up"
updatedStrength = currentStrength + Δstrength
else
updatedStrength = currentStrength - Δstrength
end
updatedStrength -= bias
return updatedStrength
return updatedStrength::Float64
end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
@@ -318,47 +316,6 @@ function normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)
normalize!(subvector, 1)
end
""" rewire of neuron synaptic connection that has 0 weight. With connection's excitatory and
inhabitory ratio constraint.
"""
# function neuroplasticity!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
# nExcitatory::Vector, nInhabitory::Vector, excitatoryPercent::Integer)
# # if there is 0-weight then replace it with new connection
# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
# desiredEx = Int(floor((excitatoryPercent / 100) * length(n.subscriptionList)))
# desiredIn = length(n.subscriptionList) - desiredEx
# wRecSign = sign.(n.wRec)
# inConn = sum(isequal.(wRecSign, -1))
# # random new synaptic connection
# inConnToAdd = desiredIn - inConn
# if inConnToAdd <= 0
# # skip all new Conn will be excitatory type
# else
# newConnVecSign = ones(length(zeroWeightConnIndex))
# newConnVecSign = view(newConnVecSign, 1:inConnToAdd) * -1
# end
# # new synaptic connection must sample fron neuron that fires
# inPool = nInhabitory ∩ firedNeurons
# filter!(x -> x ∉ [n.id], inPool) # exclude this neuron id from the list
# filter!(x -> x ∉ n.subscriptionList, inPool) # exclude this neuron's subscriptionList from the list
# exPool = nExcitatory ∩ firedNeurons
# filter!(x -> x ∉ [n.id], exPool) # exclude this neuron id from the list
# filter!(x -> x ∉ n.subscriptionList, exPool) # exclude this neuron's subscriptionList from the list
# w = [rand(0.01:0.01:0.2, length(zeroWeightConnIndex))] .* newConnVecSign
# synapticStrength = [rand(-5:0.01:-4, length(zeroWeightConnIndex))]
# # add new synaptic connection to neuron
# for (i, connIndex) in enumerate(zeroWeightConnIndex)
# n.subscriptionList[connIndex] = newConnVecSign[i] < 0 ? pop!(inPool) : pop!(exPool)
# n.wRec[connIndex] = w[i]
# n.synapticStrength[connIndex] = synapticStrength[i]
# end
# end
""" rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory and
inhabitory ratio constraint.
"""

View File

@@ -9,7 +9,7 @@ export
instantiate_custom_types, init_neuron, populate_neuron,
add_neuron!
using Random, Flux, LinearAlgebra
using Random, LinearAlgebra
#------------------------------------------------------------------------------------------------100
@@ -350,7 +350,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
refractoryCounter::Int64 = 0
tau_m::Float64 = 0.0 # τ_m, membrane time constant in millisecond
eta::Float64 = 0.0001 # η, learning rate
eta::Float64 = 0.01 # η, learning rate
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
recSignal::Float64 = 0.0 # incoming recurrent signal
alpha_v_t::Float64 = 0.0 # alpha * v_t
@@ -438,7 +438,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
eRec::Array{Float64} = Float64[] # neuron's eligibility trace
eta::Float64 = 0.0001 # eta, learning rate
eta::Float64 = 0.01 # eta, learning rate
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
phi::Float64 = 0.0 # ϕ, psuedo derivative
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
@@ -448,7 +448,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
recSignal::Float64 = 0.0 # incoming recurrent signal
alpha_v_t::Float64 = 0.0 # alpha * v_t
error::Float64 = 0.0 # local neuron error
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
# optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
firingCounter::Int64 = 0 # store how many times neuron fires
firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
@@ -548,7 +548,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
refractoryCounter::Int64 = 0
tau_out::Float64 = 0.0 # τ_out, membrane time constant in millisecond
eta::Float64 = 0.0001 # η, learning rate
eta::Float64 = 0.01 # η, learning rate
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
recSignal::Float64 = 0.0 # incoming recurrent signal
alpha_v_t::Float64 = 0.0 # alpha * v_t
@@ -588,21 +588,21 @@ end
#------------------------------------------------------------------------------------------------100
function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
if optimiser_name == "AdaBelief"
params = (0.01, (0.9, 0.8))
return Flux.Optimise.AdaBelief(params...)
elseif optimiser_name == "AdaBelief2"
# output neuron requires slower change pace so η is lower than compute neuron at 0.007
# because if w_out change too fast, compute neuron will not able to
# grapse output neuron moving direction i.e. both compute neuron's direction and
# output neuron direction are out of sync.
params = (0.007, (0.9, 0.8))
return Flux.Optimise.AdaBelief(params...)
else
error("optimiser is not defined yet in load_optimiser()")
end
end
# function load_optimiser(optimiser_name::String; params::Union{Dict,Nothing} = nothing)
# if optimiser_name == "AdaBelief"
# params = (0.01, (0.9, 0.8))
# return Flux.Optimise.AdaBelief(params...)
# elseif optimiser_name == "AdaBelief2"
# # output neuron requires slower change pace so η is lower than compute neuron at 0.007
# # because if w_out change too fast, compute neuron will not able to
# # grapse output neuron moving direction i.e. both compute neuron's direction and
# # output neuron direction are out of sync.
# params = (0.007, (0.9, 0.8))
# return Flux.Optimise.AdaBelief(params...)
# else
# error("optimiser is not defined yet in load_optimiser()")
# end
# end
function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams::Dict)
n.id = id