# Ironpen/src/forward.jl
# (Julia source; header captured from a web file view: 292 lines, 8.6 KiB)
module forward
using Statistics, Random, LinearAlgebra, JSON3
using GeneralUtils
using ..types, ..snn_utils
#------------------------------------------------------------------------------------------------100
"""
    (m::model)(input_data::AbstractVector)

Model forward pass: advance the model clock by one step and evaluate its knowledge function.

`m.timeStep` is incremented first. Then the knowledge function stored under the key `:I`
of `m.knowledgeFn` is called as `kfn(m, input_data)`, and its result is returned
unchanged. Only this single knowledge function is evaluated. If it is a `kfn_1`, the
result is the tuple produced by the `kfn_1` forward pass: output spikes, output
potentials, fired-neuron count and the debug signal sums.
"""
function (m::model)(input_data::AbstractVector)
    m.timeStep += 1
    # a model currently owns exactly one knowledge function (:I), so its output
    # (including its error-related values) doubles as the model output
    primaryKfn = m.knowledgeFn[:I]
    return primaryKfn(m, input_data)
end
#------------------------------------------------------------------------------------------------100
"""
    (kfn::kfn_1)(m::model, input_data::AbstractVector)

knowledgeFn forward pass: run one time step of a `kfn_1` knowledge function.

1. Copy `timeStep` and `learningStage` from the model. If the stage is
   `"start_learning"`, do the following:
   - Call `resetLearningParams!` on every hidden and output neuron.
   - Clear `firedNeurons`, `firedNeurons_t0` and `firedNeurons_t1`.
   - Switch both `kfn` and `m` to `"learning"`.

   The reset is done here rather than at the end of learning so that the neurons'
   parameters stay available for logging/visualisation until the next episode.
2. Build a random Bool noise vector with the same length as `input_data` and prepend it
   to the input. The noise therefore drives neurons `1:length(input_data)`, and the real
   data drives the neurons that follow.
3. Call `timestep_forward!` on every hidden and output neuron. Then write the extended
   input into the `z_t1` field of the leading neurons of `kfn.neuronsArray`.
4. Snapshot every hidden neuron's `z_t` into `kfn.firedNeurons_t0`; the lif/alif
   forward passes read this vector. Run all hidden neurons, snapshot their `z_t1` into
   `kfn.firedNeurons_t1` (read by the output neurons), and append the ids of the neurons
   that fired to `kfn.firedNeurons`, keeping it duplicate-free.
5. Run every output neuron.

Returns the tuple `(out, outputNeuron_v_t1, nFired, kfn.exSignalSum, kfn.inSignalsum)`:
- `out`: output spikes (`Array{Bool}`).
- `outputNeuron_v_t1`: output membrane potentials (`Array{Float64}`).
- `nFired`: the number of hidden neurons that fired this step.
- `kfn.exSignalSum`, `kfn.inSignalsum`: the debug signal sums.
"""
function (kfn::kfn_1)(m::model, input_data::AbstractVector)
    kfn.timeStep = m.timeStep
    kfn.learningStage = m.learningStage

    if kfn.learningStage == "start_learning"
        # epsilonRec counts how often each synapse fires and drives the weight update,
        # so every neuron must start a new learning episode from a clean state
        for neuron in Iterators.flatten((kfn.neuronsArray, kfn.outputNeuronsArray))
            resetLearningParams!(neuron)
        end
        kfn.firedNeurons = Int64[]
        kfn.firedNeurons_t0 = Bool[]
        kfn.firedNeurons_t1 = Bool[]
        kfn.learningStage = "learning"
        m.learningStage = kfn.learningStage
    end

    # noise must occupy the first neuron ids; the real input follows it
    noise = [GeneralUtils.randomChoiceWithProb([true, false], [0.5, 0.5])
             for _ in 1:length(input_data)]
    extendedInput = [noise; input_data]

    for neuron in Iterators.flatten((kfn.neuronsArray, kfn.outputNeuronsArray))
        timestep_forward!(neuron)
    end

    # one data point per input neuron, starting from neuron id 1
    for (idx, value) in pairs(extendedInput)
        kfn.neuronsArray[idx].z_t1 = value
    end

    kfn.firedNeurons_t0 = [neuron.z_t for neuron in kfn.neuronsArray]
    for neuron in kfn.neuronsArray
        neuron(kfn)
    end
    kfn.firedNeurons_t1 = [neuron.z_t1 for neuron in kfn.neuronsArray]
    # ids of neurons that ever fired; used for random new neuron connections
    unique!(append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)))

    for neuron in kfn.outputNeuronsArray
        neuron(kfn)
    end

    spikesOut = [neuron.z_t1 for neuron in kfn.outputNeuronsArray]
    potentialsOut = [neuron.v_t1 for neuron in kfn.outputNeuronsArray]
    return spikesOut::Array{Bool}, potentialsOut::Array{Float64}, sum(kfn.firedNeurons_t1),
           kfn.exSignalSum, kfn.inSignalsum
end
#------------------------------------------------------------------------------------------------100
"""
    (n::passthroughNeuron)(kfn::knowledgeFn)

passthroughNeuron forward pass: only synchronise the neuron's `timeStep` with `kfn`.

No membrane dynamics are computed. Presumably these are the input neurons whose `z_t1`
the `kfn_1` forward pass writes directly — confirm against the network construction code.
Returns the assigned time step.
"""
(n::passthroughNeuron)(kfn::knowledgeFn) = (n.timeStep = kfn.timeStep)
#------------------------------------------------------------------------------------------------100
"""
    (n::lifNeuron)(kfn::knowledgeFn)

lifNeuron forward pass: advance a leaky integrate-and-fire neuron by one time step.

The spikes of the subscribed neurons at time t are read from
`kfn.firedNeurons_t0[n.subscriptionList]` into `n.z_i_t` and added to
`n.z_i_t_commulative`.

- While refractory (`n.refractoryCounter != 0`):
  - the counter is decremented and no spike is emitted;
  - the recurrent input is zeroed and the membrane potential only decays
    (`v_t1 = alpha * v_t`);
  - the pseudo-derivative `phi` is set to 0, as in the alifNeuron forward pass.

  The eligibility trace `epsilonRec` is left untouched.
- Otherwise:
  - `v_t1 = alpha * v_t + sum(wRec .* z_i_t)`, passed through `no_negative!`;
  - if `v_t1 > v_th`, the neuron spikes, enters its refractory period, increments
    `firingCounter` and resets to `vRest`;
  - `phi` and `epsilonRec` are then updated.
"""
function (n::lifNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
    # pull the firing status of the subscribed neurons at time t
    n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
    n.z_i_t_commulative += n.z_i_t
    if n.refractoryCounter != 0
        n.refractoryCounter -= 1
        # a spike lasts only one time step and is followed by a refractory period, so
        # the neuron is silent now (z_t1 is used by timestep_forward! in kfn)
        n.z_t1 = false
        n.recSignal = n.recSignal * 0.0
        # passive decay of the membrane potential
        n.v_t1 = n.alpha * n.v_t
        # no surrogate gradient while refractory (consistent with alifNeuron); without
        # this, phi would keep the value computed at the spike step
        n.phi = 0
    else
        # signal from the neurons this neuron subscribes to
        n.recSignal = sum(n.wRec .* n.z_i_t)
        n.alpha_v_t = n.alpha * n.v_t
        n.v_t1 = n.alpha_v_t + n.recSignal
        n.v_t1 = no_negative!(n.v_t1)
        if n.v_t1 > n.v_th
            n.z_t1 = true
            n.refractoryCounter = n.refractoryDuration
            n.firingCounter += 1
            n.v_t1 = n.vRest
        else
            n.z_t1 = false
        end
        # pseudo-derivative; unlike alifNeuron the distance is measured to the fixed v_th.
        # NOTE(review): the e-prop formulation (Bellec et al. 2020) uses |v - v_th| and
        # the pre-reset potential. Here there is no abs(), and v_t1 is already vRest
        # after a spike. Confirm this is intentional.
        n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    end
end
#------------------------------------------------------------------------------------------------100
"""
    (n::alifNeuron)(kfn::knowledgeFn)

alifNeuron forward pass: advance an adaptive leaky integrate-and-fire neuron by one step.

Like the lifNeuron pass, it reads the subscribed spikes at time t from
`kfn.firedNeurons_t0[n.subscriptionList]`. On every step, refractory or not, the
adaptation variable is also updated from the neuron's own previous spike:
`a = rho * a + (1 - rho) * z_t`.

- Not refractory:
  - the adaptive threshold is `av_th = v_th + beta * a`;
  - the potential `v_t1 = alpha * v_t + sum(wRec .* z_i_t)` is passed through
    `no_negative!`;
  - if `v_t1 > av_th`, the neuron spikes, enters its refractory period, increments
    `firingCounter` and resets to `vRest`;
  - `phi`, `epsilonRec` and `epsilonRecA` are then updated.
- Refractory:
  - the counter is decremented and no spike is emitted;
  - the recurrent input is zeroed and the potential only decays;
  - `phi` is set to 0.
"""
function (n::alifNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
    n.z_i_t = kfn.firedNeurons_t0[n.subscriptionList]
    n.z_i_t_commulative += n.z_i_t

    # the adaptation update is identical in both branches below, so do it once here
    n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)

    if n.refractoryCounter == 0
        n.av_th = n.v_th + (n.beta * n.a)
        # signal from the neurons this neuron subscribes to
        n.recSignal = sum(n.wRec .* n.z_i_t)
        n.alpha_v_t = n.alpha * n.v_t
        n.v_t1 = n.alpha_v_t + n.recSignal
        n.v_t1 = no_negative!(n.v_t1)
        if n.v_t1 > n.av_th
            n.z_t1 = true
            n.refractoryCounter = n.refractoryDuration
            n.firingCounter += 1
            n.v_t1 = n.vRest
        else
            n.z_t1 = false
        end
        # differs from lifNeuron: distance is measured to the adaptive threshold av_th.
        # NOTE(review): e-prop (Bellec et al. 2020) uses |v - av_th| and the pre-reset
        # potential here; confirm the missing abs() is intentional.
        n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.av_th) / n.v_th)
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
        n.epsilonRecA = (n.phi * n.epsilonRec) +
                        ((n.rho - (n.phi * n.beta)) * n.epsilonRecA)
    else
        n.refractoryCounter -= 1
        # spikes last a single step; the neuron stays silent while refractory
        # (z_t1 is used by timestep_forward! in kfn)
        n.z_t1 = false
        n.recSignal = n.recSignal * 0.0
        n.v_t1 = n.alpha * n.v_t  # passive decay only
        n.phi = 0
    end
end
#------------------------------------------------------------------------------------------------100
"""
    (n::linearNeuron)(kfn::knowledgeFn)

linearNeuron (output neuron) forward pass: advance the neuron by one time step.

In this implementation each output neuron is fully connected to every lif and alif
neuron. Unlike the hidden neurons, it reads the spikes of the current step from
`kfn.firedNeurons_t1[n.subscriptionList]`. The `kfn_1` forward pass fills that vector
after running the hidden neurons.

- While refractory:
  - the counter is decremented and no spike is emitted;
  - the recurrent input is zeroed and the potential only decays;
  - `vError` records the potential;
  - `phi` is set to 0, as in the alifNeuron forward pass.
- Otherwise:
  - `v_t1 = alpha * v_t + sum(wRec .* z_i_t)`, passed through `no_negative!`;
  - `vError` records this value before any spike reset;
  - if `v_t1 > v_th`, the neuron spikes and resets to `vRest`;
  - `phi` and `epsilonRec` are then updated.

Debugging aid: for the neuron with `id == 1`, the positive and negative per-synapse
contributions are added to `kfn.exSignalSum` and `kfn.inSignalsum`. They are not reset
here.
"""
function (n::linearNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
    # pull the firing status of the subscribed neurons for the current step
    n.z_i_t = getindex(kfn.firedNeurons_t1, n.subscriptionList)
    n.z_i_t_commulative += n.z_i_t
    if n.refractoryCounter != 0
        n.refractoryCounter -= 1
        # a spike lasts only one time step and is followed by a refractory period, so
        # the neuron is silent now (z_t1 is used by timestep_forward! in kfn)
        n.z_t1 = false
        n.recSignal = n.recSignal * 0.0
        # passive decay of the membrane potential
        n.v_t1 = n.alpha * n.v_t
        n.vError = n.v_t1  # store voltage that will be used to calculate error later
        # no surrogate gradient while refractory (consistent with alifNeuron); without
        # this, phi would keep the value computed at the spike step
        n.phi = 0
    else
        recSignal = n.wRec .* n.z_i_t
        if n.id == 1  # FIXME: debugging a dead output neuron
            for s in recSignal
                if s > 0
                    kfn.exSignalSum += s
                elseif s < 0
                    kfn.inSignalsum += s
                end
            end
        end
        # signal from the neurons this neuron subscribes to
        n.recSignal = sum(recSignal)
        n.alpha_v_t = n.alpha * n.v_t
        n.v_t1 = n.alpha_v_t + n.recSignal
        n.v_t1 = no_negative!(n.v_t1)
        n.vError = n.v_t1  # pre-reset voltage, used to calculate the error later
        if n.v_t1 > n.v_th
            n.z_t1 = true
            n.refractoryCounter = n.refractoryDuration
            n.firingCounter += 1
            n.v_t1 = n.vRest
        else
            n.z_t1 = false
        end
        # pseudo-derivative; unlike alifNeuron the distance is measured to the fixed v_th.
        # NOTE(review): e-prop (Bellec et al. 2020) uses |v - v_th| and the pre-reset
        # potential; confirm the missing abs() / post-reset evaluation is intentional.
        n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    end
end
end # end module