module snn_utils

using Flux.Optimise: apply!

export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
    precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
    reset_z_t!, reset_learning_params!, reset_learning_history_params!, cal_v_reg!,
    calculate_w_change_end!, firing_rate_error!, firing_rate_regulator!, update_Bn!,
    cal_firing_reg!, neuroplasticity!, shakeup!, reset_learning_no_wchange!,
    adjust_internal_learning_rate!, gradient_withloss

using Statistics, Random, LinearAlgebra, Distributions, Zygote
using ..types

#------------------------------------------------------------------------------------------------100

"""
    timestep_forward!(x)

Copy the `*_t1` state fields of neuron `x` into the matching `*_t` fields:
`z_t1` for passthrough neurons, `z_t1` and `v_t1` for compute neurons, and
`out_t1` for linear neurons.
"""
function timestep_forward!(x::passthrough_neuron)
    x.z_t = x.z_t1
end

function timestep_forward!(x::compute_neuron)
    x.z_t = x.z_t1
    x.v_t = x.v_t1
end

function timestep_forward!(x::linear_neuron)
    x.out_t = x.out_t1
end

"""
    no_negative!(x)

Return `0.0` when `x` is negative, otherwise return `x` unchanged. Despite the
`!` suffix nothing is mutated.
"""
no_negative!(x) = x < 0.0 ? 0.0 : x

"""
    precision(x::Array{<:Array})

Return the standard deviation of the per-element means of `x` divided by the
mean of those means, expressed in percent.
"""
precision(x::Array{<:Array}) = (std(mean.(x)) / mean(mean.(x))) * 100

# ---------------------------------------------------------------------------------------------
# Reset functions for LIF/ALIF neurons.
#
# Most helpers clear a field by multiplying its current value by 0.0, so an array-valued
# field is replaced by a same-sized array of zeros.
# NOTE(review): NaN/Inf entries survive this idiom (NaN * 0.0 is NaN) -- confirm that the
# fields can never diverge before relying on these as a hard reset.
# ---------------------------------------------------------------------------------------------
reset_last_firing_time!(n::compute_neuron) = n.last_firing_time = 0.0
reset_refractory_state_active!(n::compute_neuron) = n.refractory_state_active = false
reset_v_t!(n::compute_neuron) = n.v_t = n.v_t_default
reset_z_t!(n::compute_neuron) = n.z_t = false
reset_epsilon_rec!(n::compute_neuron) = n.epsilon_rec = n.epsilon_rec * 0.0
reset_epsilon_rec_a!(n::alif_neuron) = n.epsilon_rec_a = n.epsilon_rec_a * 0.0
# `epsilon_in` and `w_in_change` may be `nothing`; in that case they are left as `nothing`.
reset_epsilon_in!(n::compute_neuron) =
    n.epsilon_in = isnothing(n.epsilon_in) ? nothing : n.epsilon_in * 0.0
reset_error!(n::Union{compute_neuron, linear_neuron}) = n.error = nothing
reset_w_in_change!(n::compute_neuron) =
    n.w_in_change = isnothing(n.w_in_change) ? nothing : n.w_in_change * 0.0
reset_w_rec_change!(n::compute_neuron) = n.w_rec_change = n.w_rec_change * 0.0
reset_a!(n::alif_neuron) = n.a = n.a * 0.0
reset_reg_voltage_a!(n::compute_neuron) = n.reg_voltage_a = n.reg_voltage_a * 0.0
reset_reg_voltage_b!(n::compute_neuron) = n.reg_voltage_b = n.reg_voltage_b * 0.0
reset_reg_voltage_error!(n::compute_neuron) =
    n.reg_voltage_error = n.reg_voltage_error * 0.0
reset_firing_counter!(n::compute_neuron) = n.firing_counter = n.firing_counter * 0.0
reset_firing_diff!(n::Union{compute_neuron, linear_neuron}) =
    n.firing_diff = n.firing_diff * 0.0
reset_previous_error!(n::compute_neuron) = n.previous_error = n.previous_error * 0.0

# Reset functions for output (linear) neurons.
reset_epsilon_j!(n::linear_neuron) = n.epsilon_j = n.epsilon_j * 0.0
reset_out_t!(n::linear_neuron) = n.out_t = n.out_t * 0.0
reset_w_out_change!(n::linear_neuron) = n.w_out_change = n.w_out_change * 0.0
reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0

# Reset a part of the learning-related params that are used to collect learning history
# during a learning session. (Currently disabled; kept for reference.)
# function reset_learning_no_wchange!(n::lif_neuron)
#     reset_epsilon_rec!(n)
#     # reset_v_t!(n)
#     # reset_z_t!(n)
#     # reset_reg_voltage_a!(n)
#     # reset_reg_voltage_b!(n)
#     # reset_reg_voltage_error!(n)
#     reset_firing_counter!(n)
#     reset_firing_diff!(n)
#     reset_previous_error!(n)
#     reset_error!(n)
#     # # reset refractory state at the end of episode. Otherwise once neuron goes into
#     # # refractory state, it will stay in refractory state forever
#     # reset_refractory_state_active!(n)
# end
# function reset_learning_no_wchange!(n::Union{alif_neuron, elif_neuron})
#     reset_epsilon_rec!(n)
#     reset_epsilon_rec_a!(n)
#     reset_v_t!(n)
#     reset_z_t!(n)
#     # reset_a!(n)
#     reset_reg_voltage_a!(n)
#     reset_reg_voltage_b!(n)
#     reset_reg_voltage_error!(n)
#     reset_firing_counter!(n)
#     reset_firing_diff!(n)
#     reset_previous_error!(n)
#     reset_error!(n)
#     # reset refractory state at the end of episode. Otherwise once neuron goes into
#     # refractory state, it will stay in refractory state forever
#     reset_refractory_state_active!(n)
# end
# function reset_learning_no_wchange!(n::linear_neuron)
#     reset_epsilon_j!(n)
#     reset_out_t!(n)
#     reset_error!(n)
# end

"""
    reset_learning_params!(n)

Reset all learning-related params at the END of a learning session.

For LIF neurons this clears `epsilon_rec`, `w_rec_change`, `firing_counter`,
`firing_diff`, `previous_error` and `error`; ALIF neurons additionally clear
`epsilon_rec_a`. Passthrough neurons have nothing to reset.
"""
function reset_learning_params!(n::lif_neuron)
    reset_epsilon_rec!(n)
    reset_w_rec_change!(n)
    # Intentionally left untouched (disabled in the original implementation):
    # reset_v_t!(n)
    # reset_z_t!(n)
    # reset_reg_voltage_a!(n)
    # reset_reg_voltage_b!(n)
    # reset_reg_voltage_error!(n)
    reset_firing_counter!(n)
    reset_firing_diff!(n)
    reset_previous_error!(n)
    reset_error!(n)
    # # reset refractory state at the end of episode. Otherwise once neuron goes into
    # # refractory state, it will stay in refractory state forever
    # reset_refractory_state_active!(n)
end

function reset_learning_params!(n::alif_neuron)
    reset_epsilon_rec!(n)
    reset_epsilon_rec_a!(n)
    reset_w_rec_change!(n)
    # Intentionally left untouched (disabled in the original implementation):
    # reset_v_t!(n)
    # reset_z_t!(n)
    # reset_a!(n)
    # reset_reg_voltage_a!(n)
    # reset_reg_voltage_b!(n)
    # reset_reg_voltage_error!(n)
    reset_firing_counter!(n)
    reset_firing_diff!(n)
    reset_previous_error!(n)
    reset_error!(n)
    # # reset refractory state at the end of episode. Otherwise once neuron goes into
    # # refractory state, it will stay in refractory state forever
    # reset_refractory_state_active!(n)
end

# function reset_learning_no_wchange!(n::passthrough_neuron)
# end

function reset_learning_params!(n::passthrough_neuron)
    # skip
end

#------------------------------------------------------------------------------------------------100

"""
    store_knowledgefn_error!(kfn::knowledgeFn)

Record `kfn.knowledgeFn_error` in `kfn.recent_knowledgeFn_error`, a list holding
one error list per learning session, and keep at most the 3 most recent lists.

Behaviour depends on `kfn.learning_stage`:
- `"start_learning"`: open a new list, seeded with the current error unless it
  is `nothing`; the history is created first when it is still `nothing`.
- `"during_learning"`: append the current error to the newest list unless the
  error is `nothing`.
- `"end_learning"`: append the current error to the newest list unless the
  history itself is `nothing`.

Any other stage raises an error.
"""
function store_knowledgefn_error!(kfn::knowledgeFn)
    # condition to adjust neuron in KFN plane in addition to weight adjustment inside
    # each neuron
    stage = kfn.learning_stage
    if stage == "start_learning"
        session = kfn.knowledgeFn_error === nothing ? [] : [kfn.knowledgeFn_error]
        if kfn.recent_knowledgeFn_error === nothing
            kfn.recent_knowledgeFn_error = [session]
        else
            push!(kfn.recent_knowledgeFn_error, session)
        end
    elseif stage == "during_learning"
        if kfn.knowledgeFn_error !== nothing
            push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
        end
    elseif stage == "end_learning"
        if kfn.recent_knowledgeFn_error !== nothing
            push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
        end
    else
        error("case does not defined yet")
    end

    # The history can still be `nothing` here ("end_learning" without a prior
    # "start_learning"); calling `length` on it would throw, so guard the trim.
    history = kfn.recent_knowledgeFn_error
    if history !== nothing && length(history) > 3
        deleteat!(history, 1)
    end
end

"""
    update_Bn!(kfn::knowledgeFn)

Sum `w_out_change` over all neurons in `kfn.output_neurons_array` and decay each
output neuron's `w_out` by `Bn_wout_decay`. Then, for the i-th neuron that
follows the input neurons in `kfn.neurons_array`, add the i-th entry of that sum
to its `Bn` and decay `Bn` by the neuron's `Bn_wout_decay`.
"""
function update_Bn!(kfn::knowledgeFn)
    Δw = nothing
    for out in kfn.output_neurons_array
        Δw = Δw === nothing ? out.w_out_change : Δw + out.w_out_change
        out.w_out = out.w_out - (out.Bn_wout_decay * out.w_out) # w_out decay
    end

    # Δw = Δw / kfn.kfn_params[:linear_neuron_number] # average

    offset = kfn.kfn_params[:input_neuron_number] # skip input neurons
    for i = 1:kfn.kfn_params[:compute_neuron_number]
        cn = kfn.neurons_array[offset + i]
        cn.Bn = cn.Bn + Δw[i]
        cn.Bn = cn.Bn - (cn.Bn_wout_decay * cn.Bn) # Bn decay
    end
end

"""
    cal_v_reg!(n)

Regulates membrane potential to stay under the threshold (`v_th` for LIF,
`av_th` for ALIF neurons); the accumulated terms are later turned into a weight
change.

Accumulates into `n.reg_voltage_a` the squared amount by which `v_t1` exceeds
the threshold or falls below its negative, and into `n.reg_voltage_b` the
positive excess multiplied by the eligibility term (`epsilon_rec` for LIF,
`epsilon_rec - epsilon_rec_a` for ALIF).
"""
function cal_v_reg!(n::lif_neuron)
    # rectified linear function
    above = n.v_t1 - n.v_th
    below = -n.v_t1 - n.v_th
    component_a1 = above < 0 ? 0 : above^2
    component_a2 = below < 0 ? 0 : below^2
    n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2

    component_b = above < 0 ? 0 : above
    #FIXME: not sure the following line is correct
    n.reg_voltage_b = n.reg_voltage_b + (component_b * n.epsilon_rec)
end

function cal_v_reg!(n::alif_neuron)
    # rectified linear function
    above = n.v_t1 - n.av_th
    below = -n.v_t1 - n.av_th
    component_a1 = above < 0 ? 0 : above^2
    component_a2 = below < 0 ? 0 : below^2
    n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2

    component_b = above < 0 ? 0 : above
    #FIXME: not sure the following line is correct
    n.reg_voltage_b = n.reg_voltage_b + (component_b * (n.epsilon_rec - n.epsilon_rec_a))
end

"""
    voltage_error!(n::compute_neuron)

Store half of `n.reg_voltage_a` in `n.reg_voltage_error` and return it.
"""
function voltage_error!(n::compute_neuron)
    n.reg_voltage_error = 0.5 * n.reg_voltage_a
    return n.reg_voltage_error
end

"""
    voltage_regulator!(n::compute_neuron)

Return the voltage-regularisation weight change
`n.optimiser.eta * n.c_reg_v * n.reg_voltage_b`. No field of `n` is modified.
"""
function voltage_regulator!(n::compute_neuron)
    # running average
    return n.optimiser.eta * n.c_reg_v * n.reg_voltage_b
end

"""
    firing_rate_error(kfn::knowledgeFn)

Return half the sum of squared `firing_diff` over every neuron that follows the
input neurons in `kfn.neurons_array`.
"""
function firing_rate_error(kfn::knowledgeFn)
    first_compute = kfn.kfn_params[:input_neuron_number] + 1
    # Materialise the squares as an array on purpose: `sum` over an empty generator
    # throws, so switching to a generator could change behaviour when there are no
    # neurons after the input neurons.
    squared = [(n.firing_diff)^2 for n in kfn.neurons_array[first_compute:end]]
    return 0.5 * sum(squared)
end

"""
    firing_rate_regulator!(n::compute_neuron)

Return the firing-rate regularisation weight change. The change is
`eta * c_reg * (firing_rate - firing_rate_target) * e_rec` when the neuron fires
above its target rate, and that same value multiplied by `0.0` otherwise.
"""
function firing_rate_regulator!(n::compute_neuron)
    # n.firing_rate NOT running average (average over learning batch)
    Δw = n.optimiser.eta * n.c_reg * (n.firing_rate - n.firing_rate_target) * n.e_rec
    return n.firing_rate > n.firing_rate_target ? Δw : Δw * 0.0
end

# firing_rate = firing_counter / time_stamp * 1000
# NOTE(review): the factor 1000 suggests time_stamp is in milliseconds -- confirm.
firing_rate!(n::compute_neuron) = n.firing_rate = (n.firing_counter / n.time_stamp) * 1000
firing_diff!(n::compute_neuron) = n.firing_diff = n.firing_rate - n.firing_rate_target

"""
    neuroplasticity!(n::compute_neuron, firing_neurons_list::Vector)

Replace zero-weight recurrent connections of `n` with new connections.

Every index where `n.w_rec` is zero may, with a probability derived from
`n.optimiser.eta`, be re-subscribed to a randomly chosen id from
`firing_neurons_list` (excluding `n.id` and ids already present in
`n.subscription_list`) and given the weight `0.01`. Rewiring stops early once
no candidate ids remain.
"""
function neuroplasticity!(n::compute_neuron, firing_neurons_list::Vector)
    # if there is a 0-weight connection then replace it with a new connection
    zero_weight_index = findall(iszero, n.w_rec)
    isempty(zero_weight_index) && return nothing

    # Sample new connections from the list of neurons that fire instead of choosing
    # randomly from all compute neurons, because there is no point in connecting to a
    # neuron that does not fire, i.e. not fire = no information.
    candidates = filter(firing_neurons_list) do id
        # exclude this neuron and everything it already subscribes to
        id ∉ (n.id,) && id ∉ n.subscription_list
    end
    shuffle!(candidates)

    # percent is in range 0.1 to 10 (per the original author; depends on eta)
    new_connection_percent = 10 - ((n.optimiser.eta / 0.0001) / 10)
    percentage = [new_connection_percent, 100.0 - new_connection_percent] / 100.0

    for i in zero_weight_index
        # `pop!` throws on an empty collection; there may be more zero-weight slots
        # than eligible firing neurons, so stop once the candidates are used up.
        isempty(candidates) && break
        # NOTE(review): `Utils` is not imported in this module as far as this file
        # shows -- confirm it is reachable (e.g. re-exported by `..types`).
        if Utils.random_choices([true, false], percentage)
            n.subscription_list[i] = pop!(candidates)
            # a new connection should not send a large signal, otherwise it would throw
            # the RSNN off path. Let the weight grow via the optimiser.
            n.w_rec[i] = 0.01
        end
    end
    return nothing
end

"""
    adjust_internal_learning_rate!(n::compute_neuron)

Scale `n.internal_learning_rate` by `0.99` when the latest entry of
`n.error_diff` is negative, and by `1.005` otherwise.
"""
function adjust_internal_learning_rate!(n::compute_neuron)
    factor = n.error_diff[end] < 0.0 ? 0.99 : 1.005
    n.internal_learning_rate = n.internal_learning_rate * factor
end

"""
    push_epsilon_rec_a!(n)

Append a `0` entry to `n.epsilon_rec_a` for ALIF neurons; a no-op for LIF
neurons.
"""
function push_epsilon_rec_a!(n::lif_neuron)
    # skip
end

function push_epsilon_rec_a!(n::alif_neuron)
    push!(n.epsilon_rec_a, 0)
end

end # end module