minor fix
This commit is contained in:
@@ -34,8 +34,10 @@ using .learn
|
||||
|
||||
""" version 0.0.6
|
||||
Todo:
|
||||
[*1] if neuron not fire for a long time, reduce it conn strength
|
||||
[DONE] use abs(wRec) during neuron init
|
||||
[] use partial error update for computeNeuron
|
||||
[] use integrate_neuron_params synapticConnectionPercent = 20%
|
||||
[] add liquid time constant
|
||||
[DONE] if neuron not fire for a long time, reduce it conn strength
|
||||
[2] implement dormant connection and pruning machanism. the longer the training the longer
|
||||
0 weight stay 0.
|
||||
[] using RL to control learning signal
|
||||
|
||||
@@ -60,9 +60,11 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
end
|
||||
|
||||
# generate noise
|
||||
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.2, 0.8])
|
||||
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.01, 0.99])
|
||||
for i in 1:length(input_data)]
|
||||
# noise = [rand(rng, Distributions.Binomial(1, 0.5)) for i in 1:10] # another option
|
||||
# noise = [kfn.timeStep % 50 == 0
|
||||
# for i in 1:length(input_data)]
|
||||
|
||||
input_data = [noise; input_data] # noise must start from neuron id 1
|
||||
|
||||
@@ -95,8 +97,8 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
|
||||
return sum(kfn.firedNeurons_t1[kfn.kfnParams[:totalInputPort]+1:end])::Int,
|
||||
logit::Array{Float64},
|
||||
[i for i in kfn.neuronsArray[end].wRec[1:10]],
|
||||
[sum(i.wRec) for i in kfn.outputNeuronsArray],
|
||||
[i for i in kfn.neuronsArray[101].wRec[1:10]],
|
||||
[i.v_t1 for i in kfn.neuronsArray[101:110]],
|
||||
[sum(i.epsilonRec) for i in kfn.outputNeuronsArray],
|
||||
[sum(i.wRecChange) for i in kfn.outputNeuronsArray]
|
||||
end
|
||||
@@ -136,6 +138,8 @@ function (n::lifNeuron)(kfn::knowledgeFn)
|
||||
n.epsilonRec = n.decayedEpsilonRec
|
||||
else
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
|
||||
# computeAlpha!(n)
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
# n.v_t1 = no_negative!(n.v_t1)
|
||||
@@ -152,7 +156,7 @@ function (n::lifNeuron)(kfn::knowledgeFn)
|
||||
# there is a difference from alif formula
|
||||
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
end
|
||||
end
|
||||
|
||||
@@ -183,6 +187,7 @@ function (n::alifNeuron)(kfn::knowledgeFn)
|
||||
else
|
||||
n.av_th = n.v_th + (n.beta * n.a)
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
# computeAlpha!(n)
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
# n.v_t1 = no_negative!(n.v_t1)
|
||||
@@ -215,7 +220,7 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
||||
n.timeStep = kfn.timeStep
|
||||
|
||||
# pulling other neuron's firing status at time t
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t1, n.subscriptionList)
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t_commulative += n.z_i_t
|
||||
|
||||
if n.refractoryCounter != 0
|
||||
@@ -228,18 +233,18 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
||||
|
||||
# decay of v_t1
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
||||
|
||||
n.phi = 0.0
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec
|
||||
else
|
||||
recSignal = n.wRec .* n.z_i_t
|
||||
n.recSignal = sum(recSignal) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
|
||||
# computeAlpha!(n)
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
# n.v_t1 = no_negative!(n.v_t1)
|
||||
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
||||
|
||||
if n.v_t1 > n.v_th
|
||||
n.z_t1 = true
|
||||
n.refractoryCounter = n.refractoryDuration
|
||||
@@ -250,10 +255,10 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
||||
end
|
||||
|
||||
# there is a difference from alif formula
|
||||
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
||||
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
||||
n.decayedEpsilonRec = n.alpha * n.epsilonRec
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
end
|
||||
n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
|
||||
end
|
||||
end
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
@@ -267,7 +272,8 @@ function (n::integrateNeuron)(kfn::knowledgeFn)
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t_commulative += n.z_i_t
|
||||
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron
|
||||
# computeAlpha!(n)
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
if n.recSignal <= 0
|
||||
n.v_t1 = n.alpha_v_t
|
||||
|
||||
214
src/learn.jl
214
src/learn.jl
@@ -4,13 +4,13 @@ using Statistics, Random, LinearAlgebra, JSON3, Flux
|
||||
using GeneralUtils
|
||||
using ..types, ..snn_utils
|
||||
|
||||
export learn!, compute_wRecChange!, computeModelError
|
||||
export learn!, compute_paramsChange!, computeModelError
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
function compute_wRecChange!(m::model, modelError::Float64, outputError::Vector{Float64})
|
||||
function compute_paramsChange!(m::model, modelError::Float64, outputError::Vector{Float64})
|
||||
# normalize!(modelError)
|
||||
compute_wRecChange!(m.knowledgeFn[:I], modelError, outputError)
|
||||
compute_paramsChange!(m.knowledgeFn[:I], modelError, outputError)
|
||||
end
|
||||
|
||||
# function compute_wRecChange!(kfn::kfn_1, errors::Vector{Float64}, correctAnswer::AbstractVector)
|
||||
@@ -40,20 +40,54 @@ end
|
||||
# end
|
||||
# end
|
||||
|
||||
# function compute_paramsChange!(kfn::kfn_1, modelError::Float64, outputError::Vector{Float64})
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
# # for n in kfn.neuronsArray
|
||||
# if typeof(n) <: computeNeuron
|
||||
# # wIndex = findall(isequal.(oN.subscriptionList, n.id)) # use for error projection
|
||||
# wOut = [oN.wRec[findall(isequal.(oN.subscriptionList, n.id))[1]]
|
||||
# for oN in kfn.outputNeuronsArray]
|
||||
|
||||
# compute_wRecChange!(n, wOut, modelError)
|
||||
# # compute_alphaChange!(n, modelError)
|
||||
# compute_firingRateError!(n, kfn.kfnParams[:neuronFiringRateTarget],
|
||||
# kfn.kfnParams[:totalComputeNeuron])
|
||||
# end
|
||||
# end
|
||||
|
||||
# for oN in kfn.outputNeuronsArray
|
||||
# compute_wRecChange!(oN, outputError[oN.id])
|
||||
# # compute_alphaChaZnge!(oN, outputError[oN.id])
|
||||
# end
|
||||
# end
|
||||
|
||||
function compute_paramsChange!(kfn::kfn_1, modelError::Float64, outputError::Vector{Float64})
|
||||
|
||||
function compute_wRecChange!(kfn::kfn_1, modelError::Float64, outputError::Vector{Float64})
|
||||
Threads.@threads for n in kfn.neuronsArray
|
||||
# for n in kfn.neuronsArray
|
||||
if typeof(n)<: computeNeuron
|
||||
# wIndex = findall(isequal.(oN.subscriptionList, n.id))
|
||||
wOut = abs.([oN.wRec[findall(isequal.(oN.subscriptionList, n.id))[1]]
|
||||
for oN in kfn.outputNeuronsArray])
|
||||
compute_wRecChange!(n, wOut, modelError)
|
||||
if typeof(n) <: computeNeuron
|
||||
#WORKING
|
||||
wOut = Int64[]
|
||||
for oN in kfn.outputNeuronsArray
|
||||
wIndex = findall(isequal.(oN.subscriptionList, n.id))
|
||||
if length(wIndex) != 0
|
||||
push!(wOut, wIndex[1])
|
||||
end
|
||||
end
|
||||
|
||||
if length(wOut) != 0
|
||||
compute_wRecChange!(n, wOut, modelError)
|
||||
# compute_alphaChange!(n, modelError)
|
||||
compute_firingRateError!(n, kfn.kfnParams[:neuronFiringRateTarget],
|
||||
kfn.kfnParams[:totalComputeNeuron])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
for oN in kfn.outputNeuronsArray
|
||||
compute_wRecChange!(oN, outputError[oN.id])
|
||||
# compute_alphaChaZnge!(oN, outputError[oN.id])
|
||||
end
|
||||
end
|
||||
|
||||
@@ -61,7 +95,7 @@ function compute_wRecChange!(n::passthroughNeuron, wOut::AbstractVector, modelEr
|
||||
# skip
|
||||
end
|
||||
|
||||
function compute_wRecChange!(n::lifNeuron, wOut::AbstractVector, modelError::Float64)
|
||||
function compute_wRecChange!(n::lifNeuron, wOut::AbstractVector, modelError::Float64, )
|
||||
# how much error of this neuron 1-spike causing each output neuron's error
|
||||
nError = sum(wOut * modelError)
|
||||
|
||||
@@ -71,7 +105,8 @@ function compute_wRecChange!(n::lifNeuron, wOut::AbstractVector, modelError::Flo
|
||||
# ΔwRecChange .+= (0.2*(abs(sum(n.wRec)) / length(n.wRec)))
|
||||
# end
|
||||
n.wRecChange .+= ΔwRecChange
|
||||
reset_epsilonRec!(n)
|
||||
|
||||
# n.alphaChange += compute_alphaChange(n.eta, nError)
|
||||
end
|
||||
|
||||
function compute_wRecChange!(n::alifNeuron, wOut::AbstractVector, modelError::Float64)
|
||||
@@ -88,11 +123,18 @@ function compute_wRecChange!(n::alifNeuron, wOut::AbstractVector, modelError::Fl
|
||||
# end
|
||||
n.wRecChange .+= ΔwRecChange
|
||||
|
||||
reset_epsilonRec!(n)
|
||||
reset_epsilonRecA!(n)
|
||||
# n.alphaChange += compute_alphaChange(n.eta, nError)
|
||||
end
|
||||
|
||||
function compute_wRecChange!(n::linearNeuron, error::Float64)
|
||||
n.eRec = n.phi * n.epsilonRec
|
||||
ΔwRecChange = -n.eta * error * n.eRec
|
||||
# if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all
|
||||
# ΔwRecChange .+= (abs(sum(n.wRec)) / length(n.wRec))
|
||||
# end
|
||||
n.wRecChange .+= ΔwRecChange
|
||||
end
|
||||
|
||||
function compute_wRecChange!(n::integrateNeuron, error::Float64)
|
||||
ΔwRecChange = -n.eta * error * n.epsilonRec
|
||||
ΔbChange = -n.eta * error
|
||||
@@ -101,22 +143,19 @@ function compute_wRecChange!(n::integrateNeuron, error::Float64)
|
||||
# end
|
||||
n.wRecChange .+= ΔwRecChange
|
||||
n.bChange += ΔbChange
|
||||
reset_epsilonRec!(n)
|
||||
|
||||
# n.alphaChange += compute_alphaChange(n.eta, error)
|
||||
end
|
||||
|
||||
# function compute_wRecChange!(n::linearNeuron, error::Float64)
|
||||
# n.eRec = n.phi * n.epsilonRec
|
||||
# ΔwRecChange = -n.eta * error * n.eRec
|
||||
# # if sum(n.wRec) < 0 # prevent -sum(wRec) that causing neuron NOT fire at all
|
||||
# # ΔwRecChange .+= (abs(sum(n.wRec)) / length(n.wRec))
|
||||
# # end
|
||||
# n.wRecChange .+= ΔwRecChange
|
||||
# # reset_epsilonRec!(n)
|
||||
# end
|
||||
|
||||
# add compute_alphaChange
|
||||
compute_alphaChange(learningRate::Float64, total_wRecChange) = -learningRate * total_wRecChange
|
||||
function compute_firingRateError!(n::computeNeuron, firingRateTarget, totalComputeNeuron)
|
||||
# compute frequency error --> 1-timeStep of kfn runs fires X neurons
|
||||
# (frequency from kfn perspective)
|
||||
n.firingRateTarget = n.timeStep * firingRateTarget / totalComputeNeuron
|
||||
n.firingRate = n.firingCounter / n.timeStep
|
||||
error = n.firingRate - n.firingRateTarget
|
||||
ΔwRecChange = -n.eta * 0.1 * sign(error) * error^2
|
||||
n.wRecChange .+= ΔwRecChange
|
||||
end
|
||||
|
||||
function learn!(m::model)
|
||||
learn!(m.knowledgeFn[:I])
|
||||
@@ -125,10 +164,13 @@ end
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1)
|
||||
|
||||
# compute kfn error for each neuron
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||
# for n in kfn.neuronsArray totalNeuronFired
|
||||
if typeof(n) <: computeNeuron
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||
end
|
||||
end
|
||||
for n in kfn.outputNeuronsArray
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
||||
@@ -145,15 +187,11 @@ function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||
|
||||
wSign_0 = sign.(n.wRec) # original sign
|
||||
# n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
|
||||
wRecChange_reduceCoeff = 1.0
|
||||
# wRecChange_max = 0.2 * abs(sum(n.wRec)) # max change 20%
|
||||
# y = abs(sum(n.wRecChange))
|
||||
# if y > wRecChange_max # capping weight update
|
||||
# wRecChange_reduceCoeff = wRecChange_max / y
|
||||
# end
|
||||
n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||
# n.alpha += n.alphaChange
|
||||
|
||||
@@ -163,62 +201,112 @@ function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||
# if sum(n.wRecChange) != 0
|
||||
# normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||
# end
|
||||
|
||||
# set weight that fliped sign to 0 for random new connection
|
||||
n.wRec .*= nonFlipedSign
|
||||
# n.wRec = wRecMaxWeight!(n, max=1.0) # cap maximum weight
|
||||
n.wRec = wRecMaxWeight!(n, max=1.0) # cap maximum weight
|
||||
|
||||
# learn alpha
|
||||
# n.alpha_wSignal += n.alpha_wSignalChange
|
||||
# n.alpha_wPotential += n.alpha_wPotentialChange
|
||||
# n.alpha_b += n.alpha_bChange
|
||||
# n.alpha_wSignalChange *= 0.0
|
||||
# n.alpha_wPotentialChange *= 0.0
|
||||
# n.alpha_bChange *= 0.0
|
||||
# computeAlpha!(n)
|
||||
|
||||
# check for non firing. if neuron not fire for too long, reduce all connection strength
|
||||
if n.id ∈ firedNeurons
|
||||
n.notFireCounter = n.notFireTimeOut
|
||||
synapticConnStrength!(n, "updown")
|
||||
n.notFireTimeOut = 0
|
||||
elseif n.id ∉ firedNeurons && n.notFireCounter != n.notFireTimeOut
|
||||
n.notFireTimeOut += 1
|
||||
synapticConnStrength!(n, "updown")
|
||||
elseif n.id ∉ firedNeurons && n.notFireCounter == n.notFireCounter
|
||||
elseif n.id ∉ firedNeurons && n.notFireCounter == n.notFireTimeOut
|
||||
synapticConnStrength!(n, "down")
|
||||
else
|
||||
error("undefined condition line $(@__LINE__)")
|
||||
end
|
||||
|
||||
synapticConnStrength!(n, "updown")
|
||||
neuroplasticity!(n, firedNeurons, nExInType)
|
||||
end
|
||||
|
||||
function learn!(n::integrateNeuron, firedNeurons, nExInType, totalInputPort)
|
||||
function learn!(n::linearNeuron, firedNeurons, nExInType, totalInputPort)
|
||||
wSign_0 = sign.(n.wRec) # original sign
|
||||
# n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
wRecChange_max = 0.1 * abs(sum(n.wRec)) # max change 20%
|
||||
y = abs(sum(n.wRecChange))
|
||||
wRecChange_reduceCoeff = 1.0
|
||||
# wRecChange_max = 0.2 * abs(sum(n.wRec)) # max change 20%
|
||||
# y = abs(sum(n.wRecChange))
|
||||
# if y > wRecChange_max # capping weight update
|
||||
# wRecChange_reduceCoeff = wRecChange_max / y
|
||||
# end
|
||||
n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||
n.b += (wRecChange_reduceCoeff * n.bChange)
|
||||
# n.alpha += n.alphaChange
|
||||
n.alpha += n.alphaChange
|
||||
|
||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||
# # normalize wRec peak to prevent input signal overwhelming neuron
|
||||
# if sum(n.wRecChange) != 0
|
||||
# # normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||
# end
|
||||
|
||||
# set weight that fliped sign to 0 for random new connection
|
||||
# n.wRec .*= nonFlipedSign
|
||||
|
||||
# check for non firing. if neuron not fire for too long, reduce all connection strength
|
||||
if n.id ∈ firedNeurons
|
||||
n.notFireCounter = n.notFireTimeOut
|
||||
synapticConnStrength!(n, "updown")
|
||||
n.notFireTimeOut = 0
|
||||
elseif n.id ∉ firedNeurons && n.notFireCounter != n.notFireTimeOut
|
||||
n.notFireTimeOut += 1
|
||||
synapticConnStrength!(n, "updown")
|
||||
elseif n.id ∉ firedNeurons && n.notFireCounter == n.notFireTimeOut
|
||||
synapticConnStrength!(n, "down")
|
||||
else
|
||||
error("undefined condition line $(@__LINE__)")
|
||||
end
|
||||
|
||||
synapticConnStrength!(n, "updown")
|
||||
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||
end
|
||||
|
||||
function learn!(n::integrateNeuron, firedNeurons, nExInType, totalInputPort)
|
||||
wRecChange_reduceCoeff = 1.0
|
||||
n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||
n.b += (wRecChange_reduceCoeff * n.bChange)
|
||||
# n.alpha += n.alphaChange
|
||||
|
||||
# learn alpha
|
||||
# n.alpha_wSignal += n.alpha_wSignalChange
|
||||
# n.alpha_wPotential += n.alpha_wPotentialChange
|
||||
# n.alpha_b += n.alpha_bChange
|
||||
# n.alpha_wSignalChange *= 0.0
|
||||
# n.alpha_wPotentialChange *= 0.0
|
||||
# n.alpha_bChange *= 0.0
|
||||
# computeAlpha!(n)
|
||||
|
||||
# # check for non firing. if neuron not fire for too long, reduce all connection strength
|
||||
# if n.id ∈ firedNeurons
|
||||
# n.notFireCounter = n.notFireTimeOut
|
||||
# synapticConnStrength!(n, "updown")
|
||||
# n.notFireTimeOut = 0
|
||||
# elseif n.id ∉ firedNeurons && n.notFireCounter != n.notFireTimeOut
|
||||
# n.notFireTimeOut += 1
|
||||
# synapticConnStrength!(n, "updown")
|
||||
# elseif n.id ∉ firedNeurons && n.notFireCounter == n.notFireTimeOut
|
||||
# synapticConnStrength!(n, "down")
|
||||
# else
|
||||
# error("undefined condition line $(@__LINE__)")
|
||||
# end
|
||||
|
||||
# synapticConnStrength!(n, "updown")
|
||||
# neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||
end
|
||||
|
||||
# function learn!(n::linearNeuron, firedNeurons, nExInType, totalInputPort)
|
||||
# wSign_0 = sign.(n.wRec) # original sign
|
||||
# # n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
# wRecChange_max = 0.1 * abs(sum(n.wRec)) # max change 20%
|
||||
# y = abs(sum(n.wRecChange))
|
||||
# wRecChange_reduceCoeff = 1.0
|
||||
# # if y > wRecChange_max # capping weight update
|
||||
# # wRecChange_reduceCoeff = wRecChange_max / y
|
||||
# # end
|
||||
# n.wRec += (wRecChange_reduceCoeff * n.wRecChange)
|
||||
# n.alpha += n.alphaChange
|
||||
|
||||
# wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||
# nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||
# # normalize wRec peak to prevent input signal overwhelming neuron
|
||||
# if sum(n.wRecChange) != 0
|
||||
# # normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||
# end
|
||||
# # set weight that fliped sign to 0 for random new connection
|
||||
# # n.wRec .*= nonFlipedSign
|
||||
# # capMaxWeight!(n.wRec) # cap maximum weight
|
||||
# # synapticConnStrength!(n, "updown")
|
||||
# # neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
145
src/snn_utils.jl
145
src/snn_utils.jl
@@ -6,7 +6,8 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron,
|
||||
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
|
||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||
gradient_withloss, capMaxWeight, connStrengthAdjust, wRecMaxWeight!
|
||||
gradient_withloss, capMaxWeight, connStrengthAdjust, wRecMaxWeight!,
|
||||
computeAlpha!, compute_alphaChange!
|
||||
|
||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
||||
using GeneralUtils
|
||||
@@ -48,7 +49,6 @@ reset_firing_counter!(n::Union{computeNeuron, outputNeuron}) = n.firingCounter =
|
||||
reset_firing_diff!(n::Union{computeNeuron, outputNeuron}) = n.firingDiff = n.firingDiff * 0.0
|
||||
reset_refractoryCounter!(n::Union{computeNeuron, outputNeuron}) = n.refractoryCounter = n.refractoryCounter * 0.0
|
||||
reset_z_i_t_commulative!(n::Union{computeNeuron, outputNeuron}) = n.z_i_t_commulative = n.z_i_t_commulative * 0.0
|
||||
reset_alphaChange!(n::Union{computeNeuron, outputNeuron}) = n.alphaChange = n.alphaChange * 0.0
|
||||
|
||||
# reset function for output neuron
|
||||
reset_epsilon_j!(n::linearNeuron) = n.epsilon_j = n.epsilon_j * 0.0
|
||||
@@ -63,8 +63,7 @@ function resetLearningParams!(n::lifNeuron)
|
||||
reset_v_t!(n)
|
||||
reset_z_t!(n)
|
||||
reset_firing_counter!(n)
|
||||
reset_firing_diff!(n)
|
||||
reset_alphaChange!(n)
|
||||
|
||||
|
||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||
# refractory state, it will stay in refractory state forever
|
||||
@@ -79,8 +78,7 @@ function resetLearningParams!(n::alifNeuron)
|
||||
reset_z_t!(n)
|
||||
reset_a!(n)
|
||||
reset_firing_counter!(n)
|
||||
reset_firing_diff!(n)
|
||||
reset_alphaChange!(n)
|
||||
|
||||
|
||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||
# refractory state, it will stay in refractory state forever
|
||||
@@ -113,7 +111,7 @@ function resetLearningParams!(n::integrateNeuron)
|
||||
reset_bChange!(n)
|
||||
reset_v_t!(n)
|
||||
reset_firing_counter!(n)
|
||||
reset_alphaChange!(n)
|
||||
|
||||
end
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
@@ -334,11 +332,10 @@ function neuroplasticity!(n::computeNeuron, firedNeurons::Vector,
|
||||
filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
|
||||
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
||||
|
||||
filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
||||
filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
|
||||
w = randn(length(zeroWeightConnIndex)) / 100
|
||||
w = randn(length(zeroWeightConnIndex)) / 10
|
||||
synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex))
|
||||
|
||||
shuffle!(nFiredPool)
|
||||
@@ -356,51 +353,90 @@ function neuroplasticity!(n::computeNeuron, firedNeurons::Vector,
|
||||
newConn = popfirst!(nNonFiredPool)
|
||||
end
|
||||
n.subscriptionList[connIndex] = newConn
|
||||
n.wRec[connIndex] = abs(w[i]) * nExInTypeList[newConn]
|
||||
n.wRec[connIndex] = w[i] #* nExInTypeList[newConn]
|
||||
n.synapticStrength[connIndex] = synapticStrength[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# function neuroplasticity!(n::outputNeuron, firedNeurons::Vector,
|
||||
# nExInTypeList::Vector, totalInputNeuron::Integer)
|
||||
# # if there is 0-weight then replace it with new connection
|
||||
# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
||||
# if length(zeroWeightConnIndex) != 0
|
||||
# # new synaptic connection must sample fron neuron that fires
|
||||
# nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
||||
# filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
# filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron
|
||||
function neuroplasticity!(n::linearNeuron, firedNeurons::Vector,
|
||||
nExInTypeList::Vector, totalInputNeuron::Integer)
|
||||
# if there is 0-weight then replace it with new connection
|
||||
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
||||
if length(zeroWeightConnIndex) != 0
|
||||
# new synaptic connection must sample fron neuron that fires
|
||||
nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
||||
filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron
|
||||
|
||||
# nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
||||
# unique!(append!(nNonFiredPool, zeroWeightConnIndex))
|
||||
# filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
||||
# filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
# filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron
|
||||
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
||||
filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
||||
filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron
|
||||
|
||||
# w = randn(length(zeroWeightConnIndex)) / 100
|
||||
# synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex))
|
||||
w = randn(length(zeroWeightConnIndex)) / 10
|
||||
synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex))
|
||||
|
||||
# shuffle!(nFiredPool)
|
||||
# shuffle!(nNonFiredPool)
|
||||
shuffle!(nFiredPool)
|
||||
shuffle!(nNonFiredPool)
|
||||
|
||||
# # add new synaptic connection to neuron
|
||||
# for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
||||
# """ conn that is being replaced has to go into nNonFiredPool so
|
||||
# nNonFiredPool isn't empty """
|
||||
# push!(nNonFiredPool, n.subscriptionList[connIndex])
|
||||
# add new synaptic connection to neuron
|
||||
for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
||||
""" conn that is being replaced has to go into nNonFiredPool so
|
||||
nNonFiredPool isn't empty """
|
||||
push!(nNonFiredPool, n.subscriptionList[connIndex])
|
||||
|
||||
# if length(nFiredPool) != 0
|
||||
# newConn = popfirst!(nFiredPool)
|
||||
# else
|
||||
# newConn = popfirst!(nNonFiredPool)
|
||||
# end
|
||||
# n.subscriptionList[connIndex] = newConn
|
||||
# n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
|
||||
# n.synapticStrength[connIndex] = synapticStrength[i]
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
if length(nFiredPool) != 0
|
||||
newConn = popfirst!(nFiredPool)
|
||||
else
|
||||
newConn = popfirst!(nNonFiredPool)
|
||||
end
|
||||
n.subscriptionList[connIndex] = newConn
|
||||
n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
|
||||
n.synapticStrength[connIndex] = synapticStrength[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function neuroplasticity!(n::integrateNeuron, firedNeurons::Vector,
|
||||
nExInTypeList::Vector, totalInputNeuron::Integer)
|
||||
# if there is 0-weight then replace it with new connection
|
||||
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
|
||||
if length(zeroWeightConnIndex) != 0
|
||||
# new synaptic connection must sample fron neuron that fires
|
||||
nFiredPool = filter(x -> x ∉ [n.id], firedNeurons) # exclude this neuron id from the id list
|
||||
filter!(x -> x ∉ n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
filter!(x -> x ∉ [1:totalInputNeuron...], nFiredPool) # exclude input neuron
|
||||
|
||||
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
|
||||
# unique!(append!(nNonFiredPool, zeroWeightConnIndex))
|
||||
filter!(x -> x ∉ [n.id], nNonFiredPool) # exclude this neuron id from the id list
|
||||
filter!(x -> x ∉ n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
|
||||
filter!(x -> x ∉ [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron
|
||||
|
||||
w = randn(length(zeroWeightConnIndex)) / 10
|
||||
synapticStrength = rand(-4.5:0.1:-3.5, length(zeroWeightConnIndex))
|
||||
|
||||
shuffle!(nFiredPool)
|
||||
shuffle!(nNonFiredPool)
|
||||
|
||||
# add new synaptic connection to neuron
|
||||
for (i, connIndex) in enumerate(zeroWeightConnIndex)
|
||||
""" conn that is being replaced has to go into nNonFiredPool so
|
||||
nNonFiredPool isn't empty """
|
||||
push!(nNonFiredPool, n.subscriptionList[connIndex])
|
||||
|
||||
if length(nFiredPool) != 0
|
||||
newConn = popfirst!(nFiredPool)
|
||||
else
|
||||
newConn = popfirst!(nNonFiredPool)
|
||||
end
|
||||
n.subscriptionList[connIndex] = newConn
|
||||
n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
|
||||
n.synapticStrength[connIndex] = synapticStrength[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
""" Cap maximum weight of each neuron connection
|
||||
"""
|
||||
@@ -415,6 +451,27 @@ function wRecMaxWeight!(n::computeNeuron; max=1.0)
|
||||
end
|
||||
|
||||
|
||||
function compute_alphaChange!(n::passthroughNeuron, error::Float64) end
|
||||
|
||||
function compute_alphaChange!(n::Union{computeNeuron, outputNeuron}, error::Float64)
|
||||
if error != 0
|
||||
n.alpha_wSignalChange += -n.eta * sum(n.epsilonRec) * error
|
||||
n.alpha_wPotentialChange += -n.eta * error
|
||||
n.alpha_bChange += -n.eta * error
|
||||
else
|
||||
n.alpha_wSignalChange += n.eta
|
||||
n.alpha_wPotentialChange += n.eta
|
||||
n.alpha_bChange += n.eta
|
||||
end
|
||||
end
|
||||
|
||||
function computeAlpha!(n::Union{computeNeuron, outputNeuron})
|
||||
if sum(n.recSignal) != 0
|
||||
alphaSignal = n.alpha_wSignal * sum(n.recSignal)
|
||||
alphaV = n.alpha_wPotential * n.v_t
|
||||
n.alpha = Flux.sigmoid(alphaSignal + alphaV + n.alpha_b)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
191
src/types.jl
191
src/types.jl
@@ -10,7 +10,7 @@ export
|
||||
instantiate_custom_types, init_neuron, populate_neuron,
|
||||
add_neuron!
|
||||
|
||||
using Random, LinearAlgebra
|
||||
using Random, LinearAlgebra, Flux
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
@@ -117,7 +117,7 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
||||
nExcitatory::Array{Int64} =Int64[] # list of excitatory neuron id
|
||||
nInhabitory::Array{Int64} = Int64[] # list of inhabitory neuron id
|
||||
nExInType::Array{Int64} = Int64[] # list all neuron EX or IN
|
||||
excitatoryPercent::Int64 = 60 # percentage of excitatory neuron, inhabitory percent will be 100-ExcitatoryPercent
|
||||
excitatoryPercent::Int64 = 70 # percentage of excitatory neuron, inhabitory percent will be 100-ExcitatoryPercent
|
||||
end
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
@@ -193,13 +193,6 @@ function kfn_1(kfnParams::Dict)
|
||||
throw(error("number of compute neuron must be greater than input neuron"))
|
||||
end
|
||||
|
||||
# # Bn
|
||||
# if kfn.kfnParams[:Bn] == "random"
|
||||
# kfn.Bn = [Random.rand(0:0.001:1) for i in 1:kfn.kfnParams[:computeNeuronNumber]]
|
||||
# else # in case I want to specify manually
|
||||
# kfn.Bn = [kfn.kfnParams[:Bn] for i in 1:kfn.kfnParams[:computeNeuronNumber]]
|
||||
# end
|
||||
|
||||
# assign neurons ID by their position in kfn.neurons array because I think it is
|
||||
# straight forward way
|
||||
|
||||
@@ -229,12 +222,6 @@ function kfn_1(kfnParams::Dict)
|
||||
push!(kfn.outputNeuronsArray, neuron)
|
||||
end
|
||||
|
||||
for n in kfn.neuronsArray
|
||||
if typeof(n) <: computeNeuron
|
||||
n.firingRateTarget = kfn.kfnParams[:neuronFiringRateTarget]
|
||||
end
|
||||
end
|
||||
|
||||
# excitatory neuron to inhabitory neuron = 60:40 % of computeNeuron
|
||||
ex_number = Int(floor((kfn.excitatoryPercent/100.0) * kfn.kfnParams[:computeNeuronNumber]))
|
||||
ex_n = [1 for i in 1:ex_number]
|
||||
@@ -265,21 +252,23 @@ function kfn_1(kfnParams::Dict)
|
||||
end
|
||||
end
|
||||
|
||||
# # add ExInType into each output neuron subExInType
|
||||
# for n in kfn.outputNeuronsArray
|
||||
# try # input neuron doest have n.subscriptionList
|
||||
# for (i, sub_id) in enumerate(n.subscriptionList)
|
||||
# n_ExInType = kfn.neuronsArray[sub_id].ExInType
|
||||
# n.wRec[i] *= n_ExInType
|
||||
# end
|
||||
# catch
|
||||
# end
|
||||
# end
|
||||
# add ExInType into each output neuron subExInType
|
||||
for n in kfn.outputNeuronsArray
|
||||
try # input neuron doest have n.subscriptionList
|
||||
for (i, sub_id) in enumerate(n.subscriptionList)
|
||||
n_ExInType = kfn.neuronsArray[sub_id].ExInType
|
||||
n.wRec[i] *= n_ExInType
|
||||
end
|
||||
catch
|
||||
end
|
||||
end
|
||||
|
||||
for n in kfn.neuronsArray
|
||||
push!(kfn.nExInType, n.ExInType)
|
||||
end
|
||||
|
||||
|
||||
|
||||
return kfn
|
||||
end
|
||||
|
||||
@@ -341,8 +330,15 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
|
||||
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>-5), upperlimit=(5=>5))
|
||||
|
||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||
alphaChange::Float64 = 0.0
|
||||
|
||||
alpha::Float64 = 0.99
|
||||
alpha_wSignal::Float64 = 2.0
|
||||
alpha_wPotential::Float64 = 2.0
|
||||
alpha_b::Float64 = 2.0
|
||||
alpha_wSignalChange::Float64 = 0.0
|
||||
alpha_wPotentialChange::Float64 = 0.0
|
||||
alpha_bChange::Float64 = 0.0
|
||||
|
||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||
epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike
|
||||
decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec
|
||||
@@ -364,7 +360,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
|
||||
firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization
|
||||
firingRate::Float64 = 0.0 # running average of firing rate in Hz
|
||||
|
||||
notFireTimeOut::Int64 = 100 # consecutive count of not firing. Should be the same as batch size
|
||||
notFireTimeOut::Int64 = 10 # consecutive count of not firing. Should be the same as batch size
|
||||
notFireCounter::Int64 = 0
|
||||
|
||||
""" "inference" = no learning params will be collected.
|
||||
@@ -434,8 +430,22 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
||||
synapticStrength::Array{Float64} = Float64[]
|
||||
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>0), upperlimit=(5=>5))
|
||||
|
||||
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||
alphaChange::Float64 = 0.0
|
||||
alpha::Float64 = 0.99
|
||||
alpha_wSignal::Float64 = 2.0
|
||||
alpha_wPotential::Float64 = 2.0
|
||||
alpha_b::Float64 = 2.0
|
||||
alpha_wSignalChange::Float64 = 0.0
|
||||
alpha_wPotentialChange::Float64 = 0.0
|
||||
alpha_bChange::Float64 = 0.0
|
||||
|
||||
# alpha::Vector{Float64} = Float64[]
|
||||
# alpha_wSignal::Vector{Float64} = Float64[]
|
||||
# alpha_wPotential::Float64 = randn() / 100
|
||||
# alpha_b::Vector{Float64} = Float64[]
|
||||
# alpha_wSignalChange::Vector{Float64} = Float64[]
|
||||
# alpha_wPotentialChange::Float64 = 0.0
|
||||
# alpha_bChange::Vector{Float64} = Float64[]
|
||||
|
||||
delta::Float64 = 1.0 # δ, discreate timestep size in millisecond
|
||||
epsilonRec::Array{Float64} = Float64[] # ϵ_rec(v), eligibility vector for neuron i spike
|
||||
epsilonRecA::Array{Float64} = Float64[] # ϵ_rec(a)
|
||||
@@ -461,7 +471,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
||||
firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization
|
||||
firingRate::Float64 = 0.0 # running average of firing rate, Hz
|
||||
|
||||
notFireTimeOut::Int64 = 100 # consecutive count of not firing. Should be the same as batch size
|
||||
notFireTimeOut::Int64 = 10 # consecutive count of not firing. Should be the same as batch size
|
||||
notFireCounter::Int64 = 0
|
||||
|
||||
tau_a::Float64 = 100.0 # τ_a, adaption time constant in millisecond
|
||||
@@ -546,8 +556,16 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
||||
synapticStrength::Array{Float64} = Float64[]
|
||||
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>-5), upperlimit=(5=>5))
|
||||
|
||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from
|
||||
|
||||
alpha::Float64 = 0.99
|
||||
alpha_wSignal::Float64 = 2.0
|
||||
alpha_wPotential::Float64 = 2.0
|
||||
alpha_b::Float64 = 2.0
|
||||
alpha_wSignalChange::Float64 = 0.0
|
||||
alpha_wPotentialChange::Float64 = 0.0
|
||||
alpha_bChange::Float64 = 0.0
|
||||
|
||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||
epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike
|
||||
decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec
|
||||
@@ -562,6 +580,14 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||
|
||||
firingCounter::Int64 = 0 # store how many times neuron fires
|
||||
firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
|
||||
firingDiff::Float64 = 0.0 # e-prop supplement paper equation 5
|
||||
firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization
|
||||
firingRate::Float64 = 0.0 # running average of firing rate in Hz
|
||||
|
||||
notFireTimeOut::Int64 = 10 # consecutive count of not firing. Should be the same as batch size
|
||||
notFireCounter::Int64 = 0
|
||||
|
||||
ExInSignalSum::Float64 = 0.0
|
||||
end
|
||||
|
||||
@@ -627,8 +653,15 @@ Base.@kwdef mutable struct integrateNeuron <: outputNeuron
|
||||
synapticStrengthLimit::NamedTuple = (lowerlimit=(-5=>-5), upperlimit=(5=>5))
|
||||
|
||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||
alpha::Float64 = 0.0 # α, neuron membrane potential decay factor
|
||||
alphaChange::Float64 = 0.0
|
||||
|
||||
alpha::Float64 = 0.99
|
||||
alpha_wSignal::Float64 = 2.0
|
||||
alpha_wPotential::Float64 = 2.0
|
||||
alpha_b::Float64 = 2.0
|
||||
alpha_wSignalChange::Float64 = 0.0
|
||||
alpha_wPotentialChange::Float64 = 0.0
|
||||
alpha_bChange::Float64 = 0.0
|
||||
|
||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||
epsilonRec::Array{Float64} = Float64[] # ϵ_rec, eligibility vector for neuron spike
|
||||
decayedEpsilonRec::Array{Float64} = Float64[] # α * epsilonRec
|
||||
@@ -643,6 +676,14 @@ Base.@kwdef mutable struct integrateNeuron <: outputNeuron
|
||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||
|
||||
firingCounter::Int64 = 0 # store how many times neuron fires
|
||||
firingRateTarget::Float64 = 20.0 # neuron's target firing rate in Hz
|
||||
firingDiff::Float64 = 0.0 # e-prop supplement paper equation 5
|
||||
firingRateError::Float64 = 0.0 # local neuron error w.r.t. firing regularization
|
||||
firingRate::Float64 = 0.0 # running average of firing rate in Hz
|
||||
|
||||
notFireTimeOut::Int64 = 10 # consecutive count of not firing. Should be the same as batch size
|
||||
notFireCounter::Int64 = 0
|
||||
|
||||
ExInSignalSum::Float64 = 0.0
|
||||
end
|
||||
|
||||
@@ -699,23 +740,6 @@ function init_neuron!(id::Int64, n::passthroughNeuron, n_params::Dict, kfnParams
|
||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
end
|
||||
|
||||
# function init_neuron!(id::Int64, n::lifNeuron, kfnParams::Dict)
|
||||
# n.id = id
|
||||
# n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
# subscription_options = shuffle!([1:(kfnParams[:input_neuron_number]+kfnParams[:computeNeuronNumber])...])
|
||||
# if typeof(kfnParams[:synapticConnectionPercent]) == String
|
||||
# percent = parse(Int, kfnParams[:synapticConnectionPercent][1:end-1]) / 100
|
||||
# synapticConnectionPercent = floor(length(subscription_options) * percent)
|
||||
# n.subscriptionList = [pop!(subscription_options) for i = 1:synapticConnectionPercent]
|
||||
# end
|
||||
# filter!(x -> x != n.id, n.subscriptionList)
|
||||
# n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
# n.wRec = Random.rand(length(n.subscriptionList))
|
||||
# n.wRecChange = zeros(length(n.subscriptionList))
|
||||
# n.reg_voltage_b = zeros(length(n.subscriptionList))
|
||||
# n.alpha = calculate_α(n)
|
||||
# end
|
||||
|
||||
function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict)
|
||||
n.id = id
|
||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
@@ -728,11 +752,13 @@ function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict)
|
||||
filter!(x -> x != n.id, n.subscriptionList)
|
||||
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
# n.wRec = randn(length(n.subscriptionList))
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
n.alpha = calculate_α(n)
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
# start w/ small weight Otherwise neuron's weight will be explode in the long run
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 10
|
||||
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||
end
|
||||
|
||||
@@ -750,7 +776,9 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict,
|
||||
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||
# start w/ small weight Otherwise neuron's weight will be explode in the long run
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 10
|
||||
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
|
||||
# the more time has passed from the last time neuron was activated, the more
|
||||
@@ -761,6 +789,23 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict,
|
||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||
end
|
||||
|
||||
function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dict)
|
||||
n.id = id
|
||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
|
||||
subscription_options = shuffle!([kfnParams[:totalInputPort]+1 : kfnParams[:totalNeurons]...])
|
||||
subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) *
|
||||
kfnParams[:totalNeurons] - kfnParams[:totalInputPort]))
|
||||
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 10
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
n.alpha = calculate_k(n)
|
||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||
end
|
||||
|
||||
function init_neuron!(id::Int64, n::integrateNeuron, n_params::Dict, kfnParams::Dict)
|
||||
n.id = id
|
||||
n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
@@ -771,31 +816,17 @@ function init_neuron!(id::Int64, n::integrateNeuron, n_params::Dict, kfnParams::
|
||||
n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||
n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
n.alpha = calculate_k(n)
|
||||
|
||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
# start w/ small weight Otherwise neuron's weight will be explode in the long run
|
||||
n.wRec = randn(rng, length(n.subscriptionList)) / 10
|
||||
n.wRecChange = zeros(length(n.subscriptionList))
|
||||
n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||
n.b = randn(rng) / 100
|
||||
# start w/ small weight Otherwise neuron's weight will be explode in the long run
|
||||
n.b = randn(rng) / 10
|
||||
end
|
||||
|
||||
# function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dict)
|
||||
# n.id = id
|
||||
# n.knowledgeFnName = kfnParams[:knowledgeFnName]
|
||||
|
||||
# subscription_options = shuffle!([kfnParams[:totalInputPort]+1 : kfnParams[:totalNeurons]...])
|
||||
# subscription_numbers = Int(floor((n_params[:synapticConnectionPercent] / 100.0) *
|
||||
# kfnParams[:totalNeurons] - kfnParams[:totalInputPort]))
|
||||
# n.subscriptionList = [pop!(subscription_options) for i = 1:subscription_numbers]
|
||||
# n.synapticStrength = rand(-4.5:0.01:-4, length(n.subscriptionList))
|
||||
|
||||
# n.epsilonRec = zeros(length(n.subscriptionList))
|
||||
# n.wRec = randn(rng, length(n.subscriptionList)) / 100
|
||||
# n.wRecChange = zeros(length(n.subscriptionList))
|
||||
# n.alpha = calculate_k(n)
|
||||
# n.z_i_t_commulative = zeros(length(n.subscriptionList))
|
||||
# end
|
||||
|
||||
""" Make a neuron intended for use with knowledgeFn
|
||||
"""
|
||||
function init_neuron(id::Int64, n_params::Dict, kfnParams::Dict)
|
||||
@@ -854,6 +885,10 @@ calculate_ρ(neuron::alifNeuron) = exp(-neuron.delta / neuron.tau_a)
|
||||
calculate_k(neuron::linearNeuron) = exp(-neuron.delta / neuron.tau_out)
|
||||
calculate_k(neuron::integrateNeuron) = exp(-neuron.delta / neuron.tau_out)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user