refractoring

This commit is contained in:
2023-05-10 20:38:23 +07:00
commit 7c4a0dfa6f
15 changed files with 3195 additions and 0 deletions

319
src/snn_utils.jl Normal file
View File

@@ -0,0 +1,319 @@
module snn_utils
using Flux.Optimise: apply!
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative,
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
reset_z_t!, reset_learning_params!, reset_learning_history_params!,
cal_v_reg!, calculate_w_change_end!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss
using Statistics, Random, LinearAlgebra, Distributions, Zygote
using ..types
#------------------------------------------------------------------------------------------------100
function timestep_forward!(x::passthrough_neuron)
x.z_t = x.z_t1
end
function timestep_forward!(x::compute_neuron)
x.z_t = x.z_t1
x.v_t = x.v_t1
end
function timestep_forward!(x::linear_neuron)
x.out_t = x.out_t1
end
no_negative(n) = n < 0.0 ? 0.0 : x
precision(x::Array{<:Array}) = ( std(mean.(x)) / mean(mean.(x)) ) * 100
# reset functions for LIF/ALIF neuron
reset_last_firing_time!(n::compute_neuron) = n.last_firing_time = 0.0
reset_refractory_state_active!(n::compute_neuron) = n.refractory_state_active = false
reset_v_t!(n::compute_neuron) = n.v_t = n.v_t_default
reset_z_t!(n::compute_neuron) = n.z_t = false
reset_epsilon_rec!(n::compute_neuron) = n.epsilon_rec = n.epsilon_rec * 0.0
reset_epsilon_rec_a!(n::alif_neuron) = n.epsilon_rec_a = n.epsilon_rec_a * 0.0
reset_epsilon_in!(n::compute_neuron) = n.epsilon_in = isnothing(n.epsilon_in) ? nothing : n.epsilon_in * 0.0
reset_error!(n::Union{compute_neuron, linear_neuron}) = n.error = nothing
reset_w_in_change!(n::compute_neuron) = n.w_in_change = isnothing(n.w_in_change) ? nothing : n.w_in_change * 0.0
reset_w_rec_change!(n::compute_neuron) = n.w_rec_change = n.w_rec_change * 0.0
reset_a!(n::alif_neuron) = n.a = n.a * 0.0
reset_reg_voltage_a!(n::compute_neuron) = n.reg_voltage_a = n.reg_voltage_a * 0.0
reset_reg_voltage_b!(n::compute_neuron) = n.reg_voltage_b = n.reg_voltage_b * 0.0
reset_reg_voltage_error!(n::compute_neuron) = n.reg_voltage_error = n.reg_voltage_error * 0.0
reset_firing_counter!(n::compute_neuron) = n.firing_counter = n.firing_counter * 0.0
reset_firing_diff!(n::Union{compute_neuron, linear_neuron}) = n.firing_diff = n.firing_diff * 0.0
reset_previous_error!(n::Union{compute_neuron}) =
n.previous_error = n.previous_error * 0.0
# reset function for output neuron
reset_epsilon_j!(n::linear_neuron) = n.epsilon_j = n.epsilon_j * 0.0
reset_out_t!(n::linear_neuron) = n.out_t = n.out_t * 0.0
reset_w_out_change!(n::linear_neuron) = n.w_out_change = n.w_out_change * 0.0
reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0
""" Reset a part of learning-related params that used to collect learning history during learning
session
"""
# function reset_learning_no_wchange!(n::lif_neuron)
# reset_epsilon_rec!(n)
# # reset_v_t!(n)
# # reset_z_t!(n)
# # reset_reg_voltage_a!(n)
# # reset_reg_voltage_b!(n)
# # reset_reg_voltage_error!(n)
# reset_firing_counter!(n)
# reset_firing_diff!(n)
# reset_previous_error!(n)
# reset_error!(n)
# # # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
# # # it will stay in refractory state forever
# # reset_refractory_state_active!(n)
# end
# function reset_learning_no_wchange!(n::Union{alif_neuron, elif_neuron})
# reset_epsilon_rec!(n)
# reset_epsilon_rec_a!(n)
# reset_v_t!(n)
# reset_z_t!(n)
# # reset_a!(n)
# reset_reg_voltage_a!(n)
# reset_reg_voltage_b!(n)
# reset_reg_voltage_error!(n)
# reset_firing_counter!(n)
# reset_firing_diff!(n)
# reset_previous_error!(n)
# reset_error!(n)
# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
# # it will stay in refractory state forever
# reset_refractory_state_active!(n)
# end
# function reset_learning_no_wchange!(n::linear_neuron)
# reset_epsilon_j!(n)
# reset_out_t!(n)
# reset_error!(n)
# end
""" Reset all learning-related params at the END of learning session
"""
function reset_learning_params!(n::lif_neuron)
reset_epsilon_rec!(n)
reset_w_rec_change!(n)
# reset_v_t!(n)
# reset_z_t!(n)
# reset_reg_voltage_a!(n)
# reset_reg_voltage_b!(n)
# reset_reg_voltage_error!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
# # it will stay in refractory state forever
# reset_refractory_state_active!(n)
end
function reset_learning_params!(n::alif_neuron)
reset_epsilon_rec!(n)
reset_epsilon_rec_a!(n)
reset_w_rec_change!(n)
# reset_v_t!(n)
# reset_z_t!(n)
# reset_a!(n)
# reset_reg_voltage_a!(n)
# reset_reg_voltage_b!(n)
# reset_reg_voltage_error!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
reset_previous_error!(n)
reset_error!(n)
# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
# # it will stay in refractory state forever
# reset_refractory_state_active!(n)
end
# function reset_learning_no_wchange!(n::passthrough_neuron)
# end
function reset_learning_params!(n::passthrough_neuron)
# skip
end
#------------------------------------------------------------------------------------------------100
function store_knowledgefn_error!(kfn::knowledgeFn)
# condition to adjust nueron in KFN plane in addition to weight adjustment inside each neuron
if kfn.learning_stage == "start_learning"
if kfn.recent_knowledgeFn_error === nothing && kfn.knowledgeFn_error === nothing
kfn.recent_knowledgeFn_error = [[]]
elseif kfn.recent_knowledgeFn_error === nothing
kfn.recent_knowledgeFn_error = [[kfn.knowledgeFn_error]]
elseif kfn.recent_knowledgeFn_error !== nothing && kfn.knowledgeFn_error === nothing
push!(kfn.recent_knowledgeFn_error, [])
else
push!(kfn.recent_knowledgeFn_error, [kfn.knowledgeFn_error])
end
elseif kfn.learning_stage == "during_learning"
if kfn.knowledgeFn_error === nothing
#skip
else
push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
end
elseif kfn.learning_stage == "end_learning"
if kfn.recent_knowledgeFn_error === nothing
#skip
else
push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
end
else
error("case does not defined yet")
end
if length(kfn.recent_knowledgeFn_error) > 3
deleteat!(kfn.recent_knowledgeFn_error, 1)
end
end
function update_Bn!(kfn::knowledgeFn)
Δw = nothing
for n in kfn.output_neurons_array
Δw = Δw === nothing ? n.w_out_change : Δw + n.w_out_change
n.w_out = n.w_out - (n.Bn_wout_decay * n.w_out) # w_out decay
end
# Δw = Δw / kfn.kfn_params[:linear_neuron_number] # average
input_neuron_number = kfn.kfn_params[:input_neuron_number] # skip input neuron
for i = 1:kfn.kfn_params[:compute_neuron_number]
n = kfn.neurons_array[input_neuron_number+i]
n.Bn = n.Bn + Δw[i]
n.Bn = n.Bn - (n.Bn_wout_decay * n.Bn) # w_out decay
end
end
""" Regulates membrane potential to stay under v_th, output is weight change
"""
function cal_v_reg!(n::lif_neuron)
# retified linear function
component_a1 = n.v_t1 - n.v_th < 0 ? 0 : (n.v_t1 - n.v_th)^2
component_a2 = -n.v_t1 - n.v_th < 0 ? 0 : (-n.v_t1 - n.v_th)^2
n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2
component_b = n.v_t1 - n.v_th < 0 ? 0 : n.v_t1 - n.v_th
#FIXME: not sure the following line is correct
n.reg_voltage_b = n.reg_voltage_b + (component_b * n.epsilon_rec)
end
function cal_v_reg!(n::alif_neuron)
# retified linear function
component_a1 = n.v_t1 - n.av_th < 0 ? 0 : (n.v_t1 - n.av_th)^2
component_a2 = -n.v_t1 - n.av_th < 0 ? 0 : (-n.v_t1 - n.av_th)^2
n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2
component_b = n.v_t1 - n.av_th < 0 ? 0 : n.v_t1 - n.av_th
#FIXME: not sure the following line is correct
n.reg_voltage_b = n.reg_voltage_b + (component_b * (n.epsilon_rec - n.epsilon_rec_a))
end
function voltage_error!(n::compute_neuron)
n.reg_voltage_error = 0.5 * n.reg_voltage_a
return n.reg_voltage_error
end
function voltage_regulator!(n::compute_neuron) # running average
Δw = n.optimiser.eta * n.c_reg_v * n.reg_voltage_b
return Δw
end
function firing_rate_error(kfn::knowledgeFn)
start_id = kfn.kfn_params[:input_neuron_number] + 1
return 0.5 * sum([(n.firing_diff)^2 for n in kfn.neurons_array[start_id:end]])
end
function firing_rate_regulator!(n::compute_neuron)
# n.firing_rate NOT running average (average over learning batch)
Δw = n.optimiser.eta * n.c_reg *
(n.firing_rate - n.firing_rate_target) * n.e_rec
Δw = n.firing_rate > n.firing_rate_target ? Δw : Δw * 0.0
return Δw
end
firing_rate!(n::compute_neuron) = n.firing_rate = (n.firing_counter / n.time_stamp) * 1000
firing_diff!(n::compute_neuron) = n.firing_diff = n.firing_rate - n.firing_rate_target
function neuroplasticity!(n::compute_neuron, firing_neurons_list::Vector)
# if there is 0-weight then replace it with new connection
zero_weight_index = findall(iszero.(n.w_rec))
if length(zero_weight_index) != 0
""" sampling new connection from list of neurons that fires instead of ramdom choose from
all compute neuron because there is no point to connect to neuron that not fires i.e.
not fire = no information
"""
subscribe_options = filter(x -> x [n.id], firing_neurons_list) # exclude this neuron id from the list
filter!(x -> x n.subscription_list, subscribe_options) # exclude this neuron's subscription_list from the list
shuffle!(subscribe_options)
end
new_connection_percent = 10 - ((n.optimiser.eta / 0.0001) / 10) # percent is in range 0.1 to 10
percentage = [new_connection_percent, 100.0 - new_connection_percent] / 100.0
for i in zero_weight_index
if Utils.random_choices([true, false], percentage)
n.subscription_list[i] = pop!(subscribe_options)
n.w_rec[i] = 0.01 # new connection should not send large signal otherwise it would throw
# RSNN off path. Let weight grow by an optimiser
end
end
end
function adjust_internal_learning_rate!(n::compute_neuron)
n.internal_learning_rate = n.error_diff[end] < 0.0 ? n.internal_learning_rate * 0.99 :
n.internal_learning_rate * 1.005
end
function push_epsilon_rec_a!(n::lif_neuron)
# skip
end
function push_epsilon_rec_a!(n::alif_neuron)
push!(n.epsilon_rec_a, 0)
end
end # end module