From 9539b1b39bdb2cbb5cf6bb283a201be4a730bfc3 Mon Sep 17 00:00:00 2001 From: tonaerospace Date: Wed, 17 May 2023 22:41:09 +0700 Subject: [PATCH] bug fix --- src/forward.jl | 8 +++----- src/learn.jl | 26 ++++++++++++++------------ src/snn_utils.jl | 11 ++++++----- src/types.jl | 8 ++++---- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index ca42340..02b8eb8 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -129,7 +129,6 @@ function (n::lifNeuron)(kfn::knowledgeFn) n.v_t1 = n.alpha * n.v_t else n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed - n.alpha_v_t = n.alpha * n.v_t n.v_t1 = n.alpha_v_t + n.recSignal n.v_t1 = no_negative!(n.v_t1) @@ -171,7 +170,6 @@ function (n::alifNeuron)(kfn::knowledgeFn) n.v_t1 = n.alpha * n.v_t n.phi = 0 else - n.z_t = isnothing(n.z_t) ? false : n.z_t n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t) n.av_th = n.v_th + (n.beta * n.a) n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed @@ -205,7 +203,7 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn n.timeStep = kfn.timeStep # pulling other neuron's firing status at time t - n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList) + n.z_i_t = getindex(kfn.firedNeurons_t1, n.subscriptionList) if n.refractoryCounter != 0 n.refractoryCounter -= 1 @@ -217,13 +215,13 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn # decay of v_t1 n.v_t1 = n.alpha * n.v_t + n.vError = n.v_t1 else n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed - n.alpha_v_t = n.alpha * n.v_t n.v_t1 = n.alpha_v_t + n.recSignal n.v_t1 = no_negative!(n.v_t1) - + n.vError = n.v_t1 if n.v_t1 > n.v_th n.z_t1 = true n.refractoryCounter = n.refractoryDuration diff --git a/src/learn.jl b/src/learn.jl index 83a7fe4..c15c919 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -27,7 +27,7 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector) outs = [n.z_t1 for n in kfn.outputNeuronsArray] for (i, out) in enumerate(outs) if out != correctAnswer[i] # need to adjust weight - kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t1) * + kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100 / kfn.outputNeuronsArray[i].v_th # Threads.@threads for n in kfn.neuronsArray @@ -35,7 +35,7 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector) learn!(n, kfnError) end - learn!(kfn.outputNeuronsArray[i], kfn) + learn!(kfn.outputNeuronsArray[i], kfnError) end end @@ -43,18 +43,20 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector) if kfn.learningStage == "end_learning" # Threads.@threads for n in kfn.neuronsArray for n in kfn.neuronsArray - wSign_0 = sign.(n.wRec) # original sign - n.wRec += n.wRecChange # merge wRecChange into wRec - wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign - nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped - # normalize wRec peak to prevent input signal overwhelming neuron - normalizePeak!(n.wRec, 2) - n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection + if typeof(n) <: computeNeuron + wSign_0 = sign.(n.wRec) # original sign + n.wRec += n.wRecChange # merge wRecChange into wRec + wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign + nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped + # normalize wRec peak to prevent input signal overwhelming neuron + normalizePeak!(n.wRec, 2) + n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection - synapticConnStrength!(n) - #TODO neuroplasticity - println("") + synapticConnStrength!(n) + #TODO neuroplasticity + println("") + end end for n in kfn.outputNeuronsArray # merge wRecChange into wRec diff --git a/src/snn_utils.jl b/src/snn_utils.jl index 1d3dc98..ee1eab5 100644 --- a/src/snn_utils.jl +++ b/src/snn_utils.jl @@ -123,7 +123,7 @@ function resetLearningParams!(n::alifNeuron) # reset refractory state at the start/end of episode. Otherwise once neuron goes into # refractory state, it will stay in refractory state forever - reset_refractoryCounter!(n) + # reset_refractoryCounter!(n) end # function reset_learning_no_wchange!(n::passthroughNeuron) @@ -136,12 +136,12 @@ end function resetLearningParams!(n::linearNeuron) reset_epsilonRec!(n) reset_wRecChange!(n) - reset_v_t!(n) + # reset_v_t!(n) reset_firing_counter!(n) # reset refractory state at the start/end of episode. Otherwise once neuron goes into # refractory state, it will stay in refractory state forever - reset_refractoryCounter!(n) + # reset_refractoryCounter!(n) end #------------------------------------------------------------------------------------------------100 @@ -279,7 +279,8 @@ end function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) for (i, connStrength) in enumerate(n.synapticStrength) # check whether connStrength increase or decrease based on usage from n.epsilonRec - updown = n.epsilonRec[i] == 0.0 ? "down" : "up" + #WORKING n.epsilonRec is all 0.0 why? may b it was reset? + updown = n.epsilonRec[i] == 0.0 ? "down" : "up" updatedConnStrength = synapticConnStrength(connStrength, updown) updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) @@ -321,7 +322,7 @@ end within its radius. radius must be odd number """ function normalizePeak!(v::Vector, radius::Integer=2) - peak = findall(isequal.(v, maximum(abs.(v))))[1] + peak = findall(isequal.(abs.(v), maximum(abs.(v))))[1] upindex = peak - radius upindex = upindex < 1 ? 1 : upindex downindex = peak + radius diff --git a/src/types.jl b/src/types.jl index 5e7621d..eef8420 100644 --- a/src/types.jl +++ b/src/types.jl @@ -512,6 +512,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold vRest::Float64 = 0.0 # resting potential after neuron fired + vError::Union{Float64,Nothing} = nothing # used to compute model error z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep # zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all # neurons forward function at each timestep-by-timestep is to do every neuron @@ -539,7 +540,6 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated wRec change recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t - error::Union{Float64,Nothing} = nothing # local neuron error firingCounter::Integer = 0 # store how many times neuron fires end @@ -627,7 +627,7 @@ function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict) n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = rand(length(n.subscriptionList)) + n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList)) n.alpha = calculate_α(n) end @@ -646,7 +646,7 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict, n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = rand(length(n.subscriptionList)) + n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList)) # the more time has passed from the last time neuron was activated, the more @@ -668,7 +668,7 @@ function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dic n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = rand(length(n.subscriptionList)) + n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList)) n.wRecChange = zeros(length(n.subscriptionList)) n.alpha = calculate_k(n) end