refractoring
This commit is contained in:
105
src/forward.jl
105
src/forward.jl
@@ -12,7 +12,7 @@ using ..types, ..snn_utils
|
||||
"""
|
||||
function (m::model)(input_data::AbstractVector)
|
||||
# m.global_tick += 1
|
||||
m.time_stamp += 1
|
||||
m.timeStep += 1
|
||||
|
||||
# process all corresponding KFN
|
||||
raw_model_respond = m.knowledgeFn[:I](m, input_data)
|
||||
@@ -28,9 +28,9 @@ end
|
||||
"""
|
||||
|
||||
function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
kfn.time_stamp = m.time_stamp
|
||||
kfn.timeStep = m.timeStep
|
||||
kfn.softreset = m.softreset
|
||||
kfn.learning_stage = m.learning_stage
|
||||
kfn.learningStage = m.learningStage
|
||||
kfn.error = m.error
|
||||
|
||||
# generate noise
|
||||
@@ -40,53 +40,38 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
|
||||
input_data = [noise; input_data] # noise start from neuron id 1
|
||||
|
||||
for n in kfn.neurons_array
|
||||
for n in kfn.neuronsArray
|
||||
timestep_forward!(n)
|
||||
end
|
||||
for n in kfn.output_neurons_array
|
||||
for n in kfn.outputNeuronsArray
|
||||
timestep_forward!(n)
|
||||
end
|
||||
|
||||
kfn.learning_stage = m.learning_stage
|
||||
if kfn.learning_stage == "start_learning"
|
||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||
# don't gets wiped and can be logged for visualization later
|
||||
for n in kfn.neurons_array
|
||||
# epsilon_rec need to be reset because it counting how many each synaptic fires and
|
||||
# use this info to calculate how much synaptic weight should be adjust
|
||||
reset_learning_params!(n)
|
||||
end
|
||||
|
||||
# clear variables
|
||||
kfn.firing_neurons_list = Vector{Int64}()
|
||||
kfn.outputs = nothing
|
||||
end
|
||||
|
||||
# pass input_data into input neuron.
|
||||
# number of data point equals to number of input neuron starting from id 1
|
||||
for (i, data) in enumerate(input_data)
|
||||
kfn.neurons_array[i].z_t1 = data
|
||||
kfn.neuronsArray[i].z_t1 = data
|
||||
end
|
||||
|
||||
kfn.snn_firing_state_t0 = [n.z_t for n in kfn.neurons_array] #TODO check if it is used?
|
||||
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
|
||||
|
||||
#CHANGE Threads.@threads for n in kfn.neurons_array
|
||||
for n in kfn.neurons_array
|
||||
#CHANGE Threads.@threads for n in kfn.neuronsArray
|
||||
for n in kfn.neuronsArray
|
||||
n(kfn)
|
||||
end
|
||||
|
||||
kfn.snn_firing_state_t1 = [n.z_t1 for n in kfn.neurons_array]
|
||||
append!(kfn.firing_neurons_list, findall(kfn.snn_firing_state_t1)) # store id of neuron that fires
|
||||
if kfn.learning_stage == "end_learning" # use for random new neuron connection
|
||||
kfn.firing_neurons_list |> unique!
|
||||
kfn.firedNeurons_t1 = [n.z_t1 for n in kfn.neuronsArray]
|
||||
append!(kfn.firedNeurons, findall(kfn.firedNeurons_t1)) # store id of neuron that fires
|
||||
if kfn.learningStage == "end_learning"
|
||||
kfn.firedNeurons |> unique! # use for random new neuron connection
|
||||
end
|
||||
|
||||
# Threads.@threads for n in kfn.output_neurons_array
|
||||
for n in kfn.output_neurons_array
|
||||
# Threads.@threads for n in kfn.outputNeuronsArray
|
||||
for n in kfn.outputNeuronsArray
|
||||
n(kfn)
|
||||
end
|
||||
|
||||
out = [n.out_t1 for n in kfn.output_neurons_array]
|
||||
out = [n.out_t1 for n in kfn.outputNeuronsArray]
|
||||
|
||||
return out
|
||||
end
|
||||
@@ -96,7 +81,7 @@ end
|
||||
""" passthrough_neuron forward()
|
||||
"""
|
||||
function (n::passthrough_neuron)(kfn::knowledgeFn)
|
||||
n.time_stamp = kfn.time_stamp
|
||||
n.timeStep = kfn.timeStep
|
||||
# n.global_tick = kfn.global_tick
|
||||
end
|
||||
|
||||
@@ -105,40 +90,40 @@ end
|
||||
""" lif_neuron forward()
|
||||
"""
|
||||
function (n::lif_neuron)(kfn::knowledgeFn)
|
||||
n.time_stamp = kfn.time_stamp
|
||||
n.timeStep = kfn.timeStep
|
||||
|
||||
# pulling other neuron's firing status at time t
|
||||
n.z_i_t = getindex(kfn.snn_firing_state_t0, n.subscription_list)
|
||||
n.z_i_t .*= n.sub_ExIn_type
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t .*= n.subExInType
|
||||
|
||||
if n.refractory_counter != 0
|
||||
n.refractory_counter -= 1
|
||||
if n.refractoryCounter != 0
|
||||
n.refractoryCounter -= 1
|
||||
|
||||
# neuron is in refractory state, skip all calculation
|
||||
n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because neuron spike
|
||||
# last only 1 timestep follow by a period of refractory.
|
||||
n.recurrent_signal = n.recurrent_signal * 0.0
|
||||
n.recSignal = n.recSignal * 0.0
|
||||
|
||||
# Exponantial decay of v_t1
|
||||
n.v_t1 = n.v_t * n.alpha^(n.time_stamp - n.last_firing_time) # or n.v_t1 = n.alpha * n.v_t
|
||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
||||
else
|
||||
n.recurrent_signal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recurrent_signal
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
n.v_t1 = no_negative!.(n.v_t1)
|
||||
|
||||
if n.v_t1 > n.v_th
|
||||
n.z_t1 = true
|
||||
n.refractory_counter = n.refractory_duration
|
||||
n.firing_counter += 1
|
||||
n.v_t1 = n.v_t1 - n.v_th
|
||||
n.refractoryCounter = n.refractoryDuration
|
||||
n.firingCounter += 1
|
||||
n.v_t1 = n.vRest
|
||||
else
|
||||
n.z_t1 = false
|
||||
end
|
||||
|
||||
# there is a difference from alif formula
|
||||
n.phi = (n.gamma_pd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
||||
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -147,41 +132,41 @@ end
|
||||
""" alif_neuron forward()
|
||||
"""
|
||||
function (n::alif_neuron)(kfn::knowledgeFn)
|
||||
n.time_stamp = kfn.time_stamp
|
||||
n.timeStep = kfn.timeStep
|
||||
|
||||
n.z_i_t = getindex(kfn.snn_firing_state_t0, n.subscription_list)
|
||||
n.z_i_t .*= n.sub_ExIn_type
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t .*= n.subExInType
|
||||
|
||||
if n.refractory_counter != 0
|
||||
n.refractory_counter -= 1
|
||||
if n.refractoryCounter != 0
|
||||
n.refractoryCounter -= 1
|
||||
|
||||
# neuron is in refractory state, skip all calculation
|
||||
n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because neuron spike last only 1 timestep follow by a period of refractory.
|
||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
||||
n.recurrent_signal = n.recurrent_signal * 0.0
|
||||
n.recSignal = n.recSignal * 0.0
|
||||
|
||||
# Exponantial decay of v_t1
|
||||
n.v_t1 = n.v_t * n.alpha^(n.time_stamp - n.last_firing_time) # or n.v_t1 = n.alpha * n.v_t
|
||||
n.v_t1 = n.v_t * n.alpha^(n.timeStep - n.lastFiringTime) # or n.v_t1 = n.alpha * n.v_t
|
||||
n.phi = 0
|
||||
else
|
||||
n.z_t = isnothing(n.z_t) ? false : n.z_t
|
||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
||||
n.av_th = n.v_th + (n.beta * n.a)
|
||||
n.recurrent_signal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recurrent_signal
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
n.v_t1 = no_negative!.(n.v_t1)
|
||||
if n.v_t1 > n.av_th
|
||||
n.z_t1 = true
|
||||
n.refractory_counter = n.refractory_duration
|
||||
n.firing_counter += 1
|
||||
n.v_t1 = n.v_t1 - n.v_th
|
||||
n.refractoryCounter = n.refractoryDuration
|
||||
n.firingCounter += 1
|
||||
n.v_t1 = n.vRest
|
||||
else
|
||||
n.z_t1 = false
|
||||
end
|
||||
|
||||
# there is a difference from lif formula
|
||||
n.phi = (n.gamma_pd / n.v_th) * max(0, 1 - (n.v_t1 - n.av_th) / n.v_th)
|
||||
n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.av_th) / n.v_th)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -191,8 +176,8 @@ end
|
||||
In this implementation, each output neuron is fully connected to every lif and alif neuron.
|
||||
"""
|
||||
function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
|
||||
n.time_stamp = kfn.time_stamp
|
||||
n.out_t1 = getindex(kfn.snn_firing_state_t1, n.subscription_list)[1]
|
||||
n.timeStep = kfn.timeStep
|
||||
n.out_t1 = getindex(kfn.firedNeurons_t1, n.subscriptionList)[1]
|
||||
end
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user