module forward

using Statistics, Random, LinearAlgebra, JSON3, Flux
using GeneralUtils
using ..types, ..snn_utils

# NOTE(review): this file arrived with its physical newlines mangled. Code tokens below are
# byte-identical to the original; only line layout, comments, and docstrings were restored/added.

#------------------------------------------------------------------------------------------------100
"""
    (m::model)(input_data::AbstractVector)

Model forward pass for one simulation timestep.

Increments `m.timeStep`, then delegates to the model's single knowledge function
`m.knowledgeFn[:I]` and returns its result unchanged (see the `kfn_1` forward
below for the tuple layout).
"""
function (m::model)(input_data::AbstractVector)
    m.timeStep += 1
    # process all corresponding KFN
    # raw_model_respond, outputNeuron_v_t1, firedNeurons_t1 = m.knowledgeFn[:I](m, input_data)
    # the 2nd return (KFN error) should not be used as model error, but it is used here because
    # there is only one KFN in a model right now
    return m.knowledgeFn[:I](m, input_data)
end

#------------------------------------------------------------------------------------------------100
"""
    (kfn::kfn_1)(m::model, input_data::AbstractVector)

Knowledge-function forward pass for one timestep.

Steps, in order:
1. advance every neuron's time bookkeeping via `timestep_forward!`;
2. on `m.learningStage == "start_learning"`, reset per-neuron learning state
   (`resetLearningParams!`) and clear the fired-neuron buffers, then switch the
   stage to `"learning"` on both `kfn` and `m`;
3. prepend a Bernoulli(0.01) noise vector of the same length as `input_data`
   (noise occupies input-neuron ids starting at 1);
4. write inputs into the first `length(input_data)` neurons' `z_t1`, snapshot
   `firedNeurons_t0 = z_t` for all neurons, run every neuron forward, snapshot
   `firedNeurons_t1 = z_t1`, accumulate fired ids into `kfn.firedNeurons`;
5. run every output neuron forward and collect their `v_t1` as `logit`.

Returns a 6-tuple:
- `Int`: count of fired non-input neurons this step (ids past `kfnParams[:totalInputPort]`);
- `Array{Float64}`: output-neuron membrane potentials (`logit`);
- first 10 recurrent weights of neuron 101;
- `v_t1` of neurons 101:110;
- `sum(epsilonRec)` per output neuron;
- `sum(wRecChange)` per output neuron.

NOTE(review): elements 3–6 use hard-coded indices (101, 101:110) and look like
debug/visualization probes — confirm neuron 101 is the intended probe before
relying on them; they will throw if the network has fewer than 110 neurons.
"""
function (kfn::kfn_1)(m::model, input_data::AbstractVector)
    kfn.timeStep = m.timeStep
    for n in kfn.neuronsArray
        timestep_forward!(n)
    end
    for n in kfn.outputNeuronsArray
        timestep_forward!(n)
    end
    kfn.learningStage = m.learningStage
    if kfn.learningStage == "start_learning"
        # reset params here instead of at end_learning so that the neurons' parameter data
        # doesn't get wiped and can still be logged for visualization later
        for n in kfn.neuronsArray
            # epsilonRec needs to be reset because it counts how often each synapse fires and
            # this info is used to calculate how much each synaptic weight should be adjusted
            resetLearningParams!(n)
        end
        for n in kfn.outputNeuronsArray
            # epsilonRec needs to be reset because it counts how often each synapse fires and
            # this info is used to calculate how much each synaptic weight should be adjusted
            resetLearningParams!(n)
        end
        # clear variables
        kfn.firedNeurons = Int64[]
        kfn.firedNeurons_t0 = Bool[]
        kfn.firedNeurons_t1 = Bool[]
        kfn.learningStage = "learning"
        m.learningStage = kfn.learningStage
    end
    # generate noise: each noise input fires with probability 0.01 this timestep
    noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.01, 0.99]) for i in 1:length(input_data)]
    # noise = [rand(rng, Distributions.Binomial(1, 0.5)) for i in 1:10] # another option
    # noise = [kfn.timeStep % 50 == 0
    #          for i in 1:length(input_data)]
    input_data = [noise; input_data] # noise must start from neuron id 1
    # pass input_data into input neurons.
    # number of data points equals the number of input neurons starting from id 1
    for (i, data) in enumerate(input_data)
        kfn.neuronsArray[i].z_t1 = data
    end
    # snapshot of spikes at time t; every neuron reads from this frozen copy below, so the
    # threaded loop iterations are independent as long as each neuron writes only its own state
    kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray]
    Threads.@threads for n in kfn.neuronsArray
    # for n in kfn.neuronsArray
        n(kfn)
    end
    kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
    append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store ids of neurons that fired
    kfn.firedNeurons |> unique! # deduplicate in place; used for random new neuron connections
    Threads.@threads for n in kfn.outputNeuronsArray
    # for n in kfn.outputNeuronsArray
        n(kfn)
    end
    logit = [n.v_t1 for n in kfn.outputNeuronsArray]
    # _predict = Flux.softmax(logit)
    # predict = findall(isequal.(_predict, maximum(_predict)))[1]
    return sum(kfn.firedNeurons_t1[kfn.kfnParams[:totalInputPort]+1:end])::Int,
        logit::Array{Float64},
        [i for i in kfn.neuronsArray[101].wRec[1:10]],
        [i.v_t1 for i in kfn.neuronsArray[101:110]],
        [sum(i.epsilonRec) for i in kfn.outputNeuronsArray],
        [sum(i.wRecChange) for i in kfn.outputNeuronsArray]
end

#------------------------------------------------------------------------------------------------100
"""
    (n::passthroughNeuron)(kfn::knowledgeFn)

Passthrough (input) neuron forward: only syncs the neuron's clock with the KFN.
Its `z_t1` is set directly by the KFN forward pass, not computed here.
"""
function (n::passthroughNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
end

#------------------------------------------------------------------------------------------------100
"""
    (n::lifNeuron)(kfn::knowledgeFn)

Leaky integrate-and-fire neuron forward for one timestep.

Reads the presynaptic spike snapshot `kfn.firedNeurons_t0` at the indices in
`n.subscriptionList`, then either (a) sits out the step if refractory — decaying
`v` and `epsilonRec` by `alpha` and forcing `z_t1 = false` — or (b) integrates
`v_t1 = alpha*v_t + wRec·z_i_t`, spikes when `v_t1 > v_th` (resetting to `vRest`
and arming the refractory counter), and updates the eligibility trace
`epsilonRec = alpha*epsilonRec + z_i_t` plus the surrogate-gradient factor `phi`.
Mutates only `n`'s own fields, so it is safe under the KFN's threaded loop.
"""
function (n::lifNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
    # pull other neurons' firing status at time t (frozen snapshot taken by the KFN)
    n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
    n.z_i_t_commulative += n.z_i_t
    if n.refractoryCounter != 0
        n.refractoryCounter -= 1
        # neuron is in refractory state, skip all calculation
        n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because a neuron spike
                       # lasts only 1 timestep followed by a period of refractory.
        n.recSignal = n.recSignal * 0.0
        # decay of v_t1
        n.v_t1 = n.alpha * n.v_t
        n.phi = 0.0
        # eligibility trace decays but does not accumulate new spikes while refractory
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec
    else
        n.recSignal = sum(n.wRec .* n.z_i_t) # signal from the other neurons this neuron subscribes to
        # computeAlpha!(n)
        n.alpha_v_t = n.alpha * n.v_t
        n.v_t1 = n.alpha_v_t + n.recSignal
        # n.v_t1 = no_negative!(n.v_t1)
        if n.v_t1 > n.v_th
            n.z_t1 = true
            n.refractoryCounter = n.refractoryDuration
            n.firingCounter += 1
            n.v_t1 = n.vRest # reset membrane potential on spike
        else
            n.z_t1 = false
        end
        # there is a difference from the alif formula (fixed threshold v_th here, not av_th);
        # phi is presumably the spike pseudo-derivative scaled by gammaPd — TODO confirm
        n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    end
end

#------------------------------------------------------------------------------------------------100
"""
    (n::alifNeuron)(kfn::knowledgeFn)

Adaptive leaky integrate-and-fire neuron forward for one timestep.

Like `lifNeuron` but with an adaptation variable `a`: the effective threshold is
`av_th = v_th + beta*a`, `a` decays by `rho` each step and jumps by 1.0 on a
spike. Also maintains the adaptation eligibility trace
`epsilonRecA = phi*epsilonRec + (rho - phi*beta)*epsilonRecA`.

NOTE(review): in the refractory branch `epsilonRecA` is not decayed while
`epsilonRec` is — confirm whether that asymmetry is intentional.
"""
function (n::alifNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
    n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
    n.z_i_t_commulative += n.z_i_t
    if n.refractoryCounter != 0
        n.refractoryCounter -= 1
        # neuron is in refractory state, skip all calculation
        n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because a neuron spike
                       # lasts only 1 timestep followed by a period of refractory.
        n.a = (n.rho * n.a) # adaptation still decays during refractory
        n.recSignal = n.recSignal * 0.0
        # decay of v_t1
        n.v_t1 = n.alpha * n.v_t
        n.phi = 0.0
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec
    else
        n.av_th = n.v_th + (n.beta * n.a) # adaptive threshold for this step
        n.recSignal = sum(n.wRec .* n.z_i_t) # signal from the other neurons this neuron subscribes to
        # computeAlpha!(n)
        n.alpha_v_t = n.alpha * n.v_t
        n.v_t1 = n.alpha_v_t + n.recSignal
        # n.v_t1 = no_negative!(n.v_t1)
        if n.v_t1 > n.av_th
            n.z_t1 = true
            n.refractoryCounter = n.refractoryDuration
            n.firingCounter += 1
            n.v_t1 = n.vRest
            n.a = (n.rho * n.a) + 1.0 # adaptation bumps on spike
        else
            n.z_t1 = false
            n.a = (n.rho * n.a)
        end
        # there is a difference from the lif formula: the adaptive threshold av_th is used
        n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.av_th) / n.v_th)
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
        n.epsilonRecA = (n.phi * n.epsilonRec) + ((n.rho - (n.phi * n.beta)) * n.epsilonRecA)
    end
end

#------------------------------------------------------------------------------------------------100
"""
    (n::linearNeuron)(kfn::T) where T<:knowledgeFn

linearNeuron forward()

In this implementation, each output neuron is fully connected to every lif and
alif neuron.

NOTE(review): despite the name, the body is currently identical to the
`lifNeuron` forward (threshold, spike, refractory, eligibility trace) — confirm
whether output neurons are really meant to spike, or whether this should be a
pure linear readout.
"""
function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
    n.timeStep = kfn.timeStep
    # pull other neurons' firing status at time t
    n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
    n.z_i_t_commulative += n.z_i_t
    if n.refractoryCounter != 0
        n.refractoryCounter -= 1
        # neuron is in refractory state, skip all calculation
        n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because a neuron spike
                       # lasts only 1 timestep followed by a period of refractory.
        n.recSignal = n.recSignal * 0.0
        # decay of v_t1
        n.v_t1 = n.alpha * n.v_t
        n.phi = 0.0
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec
    else
        n.recSignal = sum(n.wRec .* n.z_i_t) # signal from the other neurons this neuron subscribes to
        # computeAlpha!(n)
        n.alpha_v_t = n.alpha * n.v_t
        n.v_t1 = n.alpha_v_t + n.recSignal
        # n.v_t1 = no_negative!(n.v_t1)
        if n.v_t1 > n.v_th
            n.z_t1 = true
            n.refractoryCounter = n.refractoryDuration
            n.firingCounter += 1
            n.v_t1 = n.vRest
        else
            n.z_t1 = false
        end
        # there is a difference from the alif formula
        n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
        n.decayedEpsilonRec = n.alpha * n.epsilonRec
        n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
    end
end

#------------------------------------------------------------------------------------------------100
"""
    (n::integrateNeuron)(kfn::knowledgeFn)

Non-spiking integrator neuron forward for one timestep.

Decays `v` by `alpha`; when the weighted input `recSignal` is positive it adds
`recSignal + b` (bias) to the potential, otherwise only the decay applies.
Maintains the same `epsilonRec` eligibility trace as the other neuron types.

NOTE(review): this neuron never sets `z_t1` and has no threshold/refractory
logic in this method — it integrates but does not spike here; confirm `z_t1`
is handled elsewhere (e.g. `timestep_forward!`) if downstream code reads it.
"""
function (n::integrateNeuron)(kfn::knowledgeFn)
    n.timeStep = kfn.timeStep
    # pull other neurons' firing status at time t
    n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
    n.z_i_t_commulative += n.z_i_t
    n.recSignal = sum(n.wRec .* n.z_i_t) # signal from the other neurons this neuron subscribes to
    # computeAlpha!(n)
    n.alpha_v_t = n.alpha * n.v_t
    if n.recSignal <= 0
        n.v_t1 = n.alpha_v_t
    else
        n.v_t1 = n.alpha_v_t + n.recSignal + n.b
    end
    # there is a difference from the alif formula
    n.decayedEpsilonRec = n.alpha * n.epsilonRec
    n.epsilonRec = n.decayedEpsilonRec + n.z_i_t
end

end # end module