diff --git a/src/Ironpen.jl b/src/Ironpen.jl index f229842..a4e532e 100644 --- a/src/Ironpen.jl +++ b/src/Ironpen.jl @@ -34,7 +34,8 @@ using .learn """ version 0.0.6 Todo: - [1] use abs(wRec) suring neuron init + [*1] if neuron not fire for a long time, reduce it conn strength + [DONE] use abs(wRec) during neuron init [2] implement dormant connection and pruning machanism. the longer the training the longer 0 weight stay 0. [] using RL to control learning signal diff --git a/src/forward.jl b/src/forward.jl index b7bf914..631767d 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -60,7 +60,7 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector) end # generate noise - noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.0, 1.0]) + noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.2, 0.8]) for i in 1:length(input_data)] # noise = [rand(rng, Distributions.Binomial(1, 0.5)) for i in 1:10] # another option @@ -95,7 +95,7 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector) return sum(kfn.firedNeurons_t1[kfn.kfnParams[:totalInputPort]+1:end])::Int, logit::Array{Float64}, - [n.v_t1 for n in kfn.outputNeuronsArray], + [i for i in kfn.neuronsArray[end].wRec[1:10]], [sum(i.wRec) for i in kfn.outputNeuronsArray], [sum(i.epsilonRec) for i in kfn.outputNeuronsArray], [sum(i.wRecChange) for i in kfn.outputNeuronsArray] diff --git a/src/learn.jl b/src/learn.jl index 022415f..5971189 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -71,7 +71,7 @@ function compute_wRecChange!(n::lifNeuron, wOut::AbstractVector, modelError::Flo # ΔwRecChange .+= (0.2*(abs(sum(n.wRec)) / length(n.wRec))) # end n.wRecChange .+= ΔwRecChange - # reset_epsilonRec!(n) + reset_epsilonRec!(n) end function compute_wRecChange!(n::alifNeuron, wOut::AbstractVector, modelError::Float64) @@ -88,8 +88,8 @@ function compute_wRecChange!(n::alifNeuron, wOut::AbstractVector, modelError::Fl # end n.wRecChange .+= ΔwRecChange - # reset_epsilonRec!(n) - # reset_epsilonRecA!(n) + 
reset_epsilonRec!(n) + reset_epsilonRecA!(n) # n.alphaChange += compute_alphaChange(n.eta, nError) end @@ -101,7 +101,7 @@ function compute_wRecChange!(n::integrateNeuron, error::Float64) # end n.wRecChange .+= ΔwRecChange n.bChange += ΔbChange - # reset_epsilonRec!(n) + reset_epsilonRec!(n) # n.alphaChange += compute_alphaChange(n.eta, error) end @@ -165,8 +165,21 @@ function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron # end # set weight that fliped sign to 0 for random new connection n.wRec .*= nonFlipedSign - # capMaxWeight!(n.wRec) # cap maximum weight - synapticConnStrength!(n, "updown") + # n.wRec = wRecMaxWeight!(n, max=1.0) # cap maximum weight + + # check for non firing. if neuron not fire for too long, reduce all connection strength + if n.id ∈ firedNeurons + n.notFireCounter = 0 + synapticConnStrength!(n, "updown") + elseif n.id ∉ firedNeurons && n.notFireCounter != n.notFireTimeOut + n.notFireCounter += 1 + synapticConnStrength!(n, "updown") + elseif n.id ∉ firedNeurons && n.notFireCounter == n.notFireTimeOut + synapticConnStrength!(n, "down") + else + error("undefined condition line $(@__LINE__)") + end + neuroplasticity!(n, firedNeurons, nExInType) end diff --git a/src/snn_utils.jl b/src/snn_utils.jl index 24c1377..c7e8898 100644 --- a/src/snn_utils.jl +++ b/src/snn_utils.jl @@ -6,7 +6,7 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!, firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!, neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!, - gradient_withloss, capMaxWeight!, connStrengthAdjust + gradient_withloss, capMaxWeight, connStrengthAdjust, wRecMaxWeight! 
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux using GeneralUtils @@ -404,9 +404,14 @@ end """ Cap maximum weight of each neuron connection """ -function capMaxWeight!(v::Vector{Float64}, max=1.0) +function capMaxWeight(v::Vector{Float64}, max=1.0) originalSign = sign.(v) - v = originalSign .* GeneralUtils.replaceMoreThan.(abs.(v), max) + return originalSign .* GeneralUtils.replaceMoreThan.(abs.(v), max) +end + +function wRecMaxWeight!(n::computeNeuron; max=1.0) + originalSign = sign.(n.wRec) + n.wRec = originalSign .* GeneralUtils.replaceMoreThan.(abs.(n.wRec), max) end @@ -425,11 +430,6 @@ end - - - - - diff --git a/src/types.jl b/src/types.jl index 51a25cc..364252c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -253,7 +253,7 @@ function kfn_1(kfnParams::Dict) try # input neuron doest have n.subscriptionList for (i, sub_id) in enumerate(n.subscriptionList) n_ExInType = kfn.neuronsArray[sub_id].ExInType - n.wRec[i] *= n_ExInType + n.wRec[i] = abs(n.wRec[i]) * n_ExInType # add id exin type to kfn if n_ExInType < 0 push!(kfn.nInhabitory, sub_id) @@ -364,6 +364,9 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization firingRate::Float64 = 0.0 # running average of firing rate in Hz + notFireTimeOut::Int64 = 100 # consecutive count of not firing. Should be the same as batch size + notFireCounter::Int64 = 0 + """ "inference" = no learning params will be collected. "learning" = neuron will accumulate epsilon_j, compute Δw_rec_change each time correct answer is available then merge Δw_rec_change into wRecChange then @@ -458,6 +461,9 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization firingRate::Float64 = 0.0 # running average of firing rate, Hz + notFireTimeOut::Int64 = 100 # consecutive count of not firing. 
Should be the same as batch size + notFireCounter::Int64 = 0 + tau_a::Float64 = 100.0 # τ_a, adaption time constant in millisecond beta::Float64 = 0.15 # β, constant, value from paper rho::Float64 = 0.0 # ρ, threshold adaptation decay factor @@ -744,7 +750,7 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict, n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList)) n.epsilonRec = zeros(length(n.subscriptionList)) - n.wRec = randn(rng, length(n.subscriptionList)) / 100 # TODO use abs() + n.wRec = randn(rng, length(n.subscriptionList)) / 100 n.wRecChange = zeros(length(n.subscriptionList)) # the more time has passed from the last time neuron was activated, the more