refractoring

This commit is contained in:
2023-05-17 14:28:43 +07:00
parent df26a01929
commit 214466d9e9
5 changed files with 36 additions and 53 deletions

View File

@@ -31,6 +31,29 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
kfn.learningStage = m.learningStage
if kfn.learningStage == "start_learning"
# reset params here instead of at the end_learning so that neuron's parameter data
# don't gets wiped and can be logged for visualization later
for n in kfn.neuronsArray
# epsilonRec need to be reset because it counting how many each synaptic fires and
# use this info to calculate how much synaptic weight should be adjust
resetLearningParams!(n)
end
for n in kfn.output_neuron
# epsilonRec need to be reset because it counting how many each synaptic fires and
# use this info to calculate how much synaptic weight should be adjust
resetLearningParams!(n)
end
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "learning"
end
# generate noise
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
for i in 1:length(input_data)]
@@ -41,6 +64,9 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
for n in kfn.neuronsArray
timestep_forward!(n)
end
for n in kfn.outputNeuronsArray
timestep_forward!(n)
end
# pass input_data into input neuron.
# number of data point equals to number of input neuron starting from id 1
@@ -89,7 +115,6 @@ function (n::lif_neuron)(kfn::knowledgeFn)
# pulling other neuron's firing status at time t
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
n.z_i_t .*= n.subExInType
if n.refractoryCounter != 0
n.refractoryCounter -= 1
@@ -102,7 +127,7 @@ function (n::lif_neuron)(kfn::knowledgeFn)
# decay of v_t1
n.v_t1 = n.alpha * n.v_t
else
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
n.recSignal = sum(n.wRec .* n.z_i_t .* n.subExInType) # signal from other neuron that this neuron subscribed
n.alpha_v_t = n.alpha * n.v_t
n.v_t1 = n.alpha_v_t + n.recSignal
@@ -132,7 +157,6 @@ function (n::alif_neuron)(kfn::knowledgeFn)
n.timeStep = kfn.timeStep
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
n.z_i_t .*= n.subExInType
if n.refractoryCounter != 0
n.refractoryCounter -= 1
@@ -149,7 +173,7 @@ function (n::alif_neuron)(kfn::knowledgeFn)
n.z_t = isnothing(n.z_t) ? false : n.z_t
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
n.av_th = n.v_th + (n.beta * n.a)
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
n.recSignal = sum(n.wRec .* n.z_i_t .* n.subExInType) # signal from other neuron that this neuron subscribed
n.alpha_v_t = n.alpha * n.v_t
n.v_t1 = n.alpha_v_t + n.recSignal
n.v_t1 = no_negative!.(n.v_t1)
@@ -181,7 +205,6 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
# pulling other neuron's firing status at time t
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
n.z_i_t .*= n.subExInType
if n.refractoryCounter != 0
n.refractoryCounter -= 1
@@ -194,7 +217,7 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
# decay of v_t1
n.v_t1 = n.alpha * n.v_t
else
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
n.recSignal = sum(n.wRec .* n.z_i_t .* n.subExInType) # signal from other neuron that this neuron subscribed
n.alpha_v_t = n.alpha * n.v_t
n.v_t1 = n.alpha_v_t + n.recSignal