module snn_utils
using Flux.Optimise: apply!
# Public API of the module.
# NOTE(review): several exported names have no definition in this file
# (`calculate_α`, `calculate_ρ`, `calculate_k`, `init_neuron`, `calculate_w_change!`,
# `interneurons_adjustment!`, `reset_learning_history_params!`, `cal_firing_reg!`,
# `shakeup!`, `gradient_withloss`); `reset_learning_no_wchange!` is commented out
# below, and `firing_rate_error!` exists here only as `firingRateError`.
# Confirm these are defined elsewhere or prune the export list.
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
    precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
    reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
    reset_epsilonRecA!, synapticConnStrength!,
    firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
    neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
    gradient_withloss
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
using GeneralUtils
using ..types

#------------------------------------------------------------------------------------------------100
"""
    timestep_forward!(x::passthrough_neuron)

Advance one simulation step: promote the buffered next-step spike value `z_t1`
into the current-step field `z_t`.
"""
timestep_forward!(x::passthrough_neuron) = x.z_t = x.z_t1
"""
    timestep_forward!(x::compute_neuron)

Advance one simulation step: promote the buffered spike `z_t1` and membrane
potential `v_t1` into the current-step fields `z_t` and `v_t`.
"""
function timestep_forward!(neuron::compute_neuron)
    neuron.z_t = neuron.z_t1
    neuron.v_t = neuron.v_t1
end
# Clamp a scalar at zero from below.
# NOTE(review): despite the trailing `!`, this does not mutate its argument.
no_negative!(x) = ifelse(x < 0.0, 0.0, x)

"""
    precision(x::Array{<:Array})

Coefficient of variation, in percent, of the per-array means of `x`:
`std(means) / mean(means) * 100`.
"""
function precision(x::Array{<:Array})
    μ = mean.(x)
    return (std(μ) / mean(μ)) * 100
end
# reset functions for LIF/ALIF neuron.
# Each one-liner rebinds a single piece of per-neuron state to its rest/zero
# value. `field * 0.0` zeroes while preserving the container's shape (an array
# becomes a zero array); fields are reassigned, not mutated in place.
# NOTE(review): `* 0.0` converts integer-valued fields to Float64 — assumed
# these fields are float-valued; confirm against the `..types` definitions.

# timing / spiking state
reset_last_firing_time!(n::compute_neuron) = n.lastFiringTime = 0.0
reset_refractory_state_active!(n::compute_neuron) = n.refractory_state_active = false
reset_v_t!(n::neuron) = n.v_t = n.vRest
reset_z_t!(n::compute_neuron) = n.z_t = false

# eligibility traces (`nothing`-valued traces stay `nothing`)
reset_epsilonRec!(n::compute_neuron) = n.epsilonRec = n.epsilonRec * 0.0
reset_epsilonRec!(n::output_neuron) = n.epsilonRec = n.epsilonRec * 0.0
reset_epsilonRecA!(n::alif_neuron) = n.epsilonRecA = n.epsilonRecA * 0.0
reset_epsilon_in!(n::compute_neuron) = n.epsilon_in = isnothing(n.epsilon_in) ? nothing : n.epsilon_in * 0.0

# learning error / pending weight changes
reset_error!(n::Union{compute_neuron, linear_neuron}) = n.error = nothing
reset_w_in_change!(n::compute_neuron) = n.w_in_change = isnothing(n.w_in_change) ? nothing : n.w_in_change * 0.0
reset_wRecChange!(n::compute_neuron) = n.wRecChange = n.wRecChange * 0.0

# adaptation / voltage-regularisation accumulators
reset_a!(n::alif_neuron) = n.a = n.a * 0.0
reset_reg_voltage_a!(n::compute_neuron) = n.reg_voltage_a = n.reg_voltage_a * 0.0
reset_reg_voltage_b!(n::compute_neuron) = n.reg_voltage_b = n.reg_voltage_b * 0.0
reset_reg_voltage_error!(n::compute_neuron) = n.reg_voltage_error = n.reg_voltage_error * 0.0

# firing statistics / refractory bookkeeping
reset_firing_counter!(n::compute_neuron) = n.firingCounter = n.firingCounter * 0.0
reset_firing_diff!(n::Union{compute_neuron, linear_neuron}) = n.firingDiff = n.firingDiff * 0.0
reset_refractoryCounter!(n::Union{compute_neuron, linear_neuron}) = n.refractoryCounter = n.refractoryCounter * 0.0
# reset function for output neuron (linear readout)
reset_epsilon_j!(n::linear_neuron) = n.epsilon_j = n.epsilon_j * 0.0   # input eligibility trace
reset_out_t!(n::linear_neuron) = n.out_t = n.out_t * 0.0               # current output value
reset_w_out_change!(n::linear_neuron) = n.w_out_change = n.w_out_change * 0.0  # pending output-weight update
reset_b_change!(n::linear_neuron) = n.b_change = n.b_change * 0.0      # pending bias update
""" Reset a part of learning-related params that used to collect learning history during learning
|
||
session
|
||
"""
|
||
# function reset_learning_no_wchange!(n::lif_neuron)
|
||
# reset_epsilonRec!(n)
|
||
# # reset_v_t!(n)
|
||
# # reset_z_t!(n)
|
||
# # reset_reg_voltage_a!(n)
|
||
# # reset_reg_voltage_b!(n)
|
||
# # reset_reg_voltage_error!(n)
|
||
# reset_firing_counter!(n)
|
||
# reset_firing_diff!(n)
|
||
# reset_previous_error!(n)
|
||
# reset_error!(n)
|
||
|
||
# # # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
|
||
# # # it will stay in refractory state forever
|
||
# # reset_refractory_state_active!(n)
|
||
# end
|
||
# function reset_learning_no_wchange!(n::Union{alif_neuron, elif_neuron})
|
||
# reset_epsilonRec!(n)
|
||
# reset_epsilonRecA!(n)
|
||
# reset_v_t!(n)
|
||
# reset_z_t!(n)
|
||
# # reset_a!(n)
|
||
# reset_reg_voltage_a!(n)
|
||
# reset_reg_voltage_b!(n)
|
||
# reset_reg_voltage_error!(n)
|
||
# reset_firing_counter!(n)
|
||
# reset_firing_diff!(n)
|
||
# reset_previous_error!(n)
|
||
# reset_error!(n)
|
||
|
||
# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
|
||
# # it will stay in refractory state forever
|
||
# reset_refractory_state_active!(n)
|
||
# end
|
||
# function reset_learning_no_wchange!(n::linear_neuron)
|
||
# reset_epsilon_j!(n)
|
||
# reset_out_t!(n)
|
||
# reset_error!(n)
|
||
# end
|
||
|
||
""" Reset all learning-related params at the END of learning session
|
||
"""
|
||
function resetLearningParams!(n::lif_neuron)
|
||
reset_epsilonRec!(n)
|
||
reset_wRecChange!(n)
|
||
# reset_v_t!(n)
|
||
# reset_z_t!(n)
|
||
reset_firing_counter!(n)
|
||
reset_firing_diff!(n)
|
||
|
||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||
# refractory state, it will stay in refractory state forever
|
||
reset_refractoryCounter!(n)
|
||
end
|
||
function resetLearningParams!(n::alif_neuron)
    # eligibility traces (membrane + threshold adaptation) and accumulated weight change
    reset_epsilonRec!(n)
    reset_epsilonRecA!(n)
    reset_wRecChange!(n)
    # reset_v_t!(n)
    # reset_z_t!(n)
    # reset_a!(n)
    # firing statistics gathered over the session
    reset_firing_counter!(n)
    reset_firing_diff!(n)

    # reset refractory state at the start/end of episode. Otherwise once neuron goes into
    # refractory state, it will stay in refractory state forever
    reset_refractoryCounter!(n)
end
# function reset_learning_no_wchange!(n::passthrough_neuron)
# end
# Pass-through neurons hold no learning state; nothing to reset.
resetLearningParams!(::passthrough_neuron) = nothing
function resetLearningParams!(n::linear_neuron)
    # NOTE(review): `reset_epsilonRec!` and `reset_wRecChange!` are declared above
    # only for compute/output neuron types — confirm `linear_neuron` dispatches to
    # a valid method (e.g. via subtyping) rather than raising a MethodError.
    reset_epsilonRec!(n)
    reset_wRecChange!(n)
    reset_v_t!(n)
    reset_firing_counter!(n)

    # reset refractory state at the start/end of episode. Otherwise once neuron goes into
    # refractory state, it will stay in refractory state forever
    reset_refractoryCounter!(n)
end
#------------------------------------------------------------------------------------------------100
"""
    store_knowledgefn_error!(kfn::knowledgeFn)

Record the current `knowledgeFn_error` into the rolling history
`recent_knowledgeFn_error`, keyed by `kfn.learningStage`:

- `"start_learning"`: open a new history entry (empty when there is no error yet)
- `"during_learning"`: append the error to the current entry (skipped when `nothing`)
- `"end_learning"`: append to the current entry when a history exists

At most the 3 most recent entries are kept.

Throws an `ErrorException` for any other `learningStage` value.
"""
function store_knowledgefn_error!(kfn::knowledgeFn)
    # condition to adjust neuron in KFN plane in addition to weight adjustment inside each neuron
    if kfn.learningStage == "start_learning"
        # open a fresh history entry for the new learning session
        entry = kfn.knowledgeFn_error === nothing ? [] : [kfn.knowledgeFn_error]
        if kfn.recent_knowledgeFn_error === nothing
            kfn.recent_knowledgeFn_error = [entry]
        else
            push!(kfn.recent_knowledgeFn_error, entry)
        end
    elseif kfn.learningStage == "during_learning"
        if kfn.knowledgeFn_error !== nothing
            push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
        end
    elseif kfn.learningStage == "end_learning"
        if kfn.recent_knowledgeFn_error !== nothing
            # NOTE(review): unlike "during_learning", this may push a `nothing`
            # error — confirm that is intended as an end-of-session marker.
            push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
        end
    else
        error("learningStage \"$(kfn.learningStage)\" is not handled yet")
    end

    # Keep only the 3 most recent entries. BUG FIX: the history can still be
    # `nothing` here (e.g. "end_learning" before any session was started); the
    # original called `length(nothing)` and threw a MethodError.
    hist = kfn.recent_knowledgeFn_error
    if hist !== nothing && length(hist) > 3
        deleteat!(hist, 1)
    end
end
"""
    update_Bn!(kfn::knowledgeFn)

Accumulate every output neuron's pending `w_out_change` into `Δw`, decay each
`w_out`, then add `Δw[i]` into the i-th compute neuron's feedback weight `Bn`
and decay `Bn` by the same `Bn_wout_decay` factor.
"""
function update_Bn!(kfn::knowledgeFn)
    # sum the pending output-weight changes across all output neurons
    Δw = nothing
    for n in kfn.outputNeuronsArray
        Δw = Δw === nothing ? n.w_out_change : Δw + n.w_out_change
        n.w_out = n.w_out - (n.Bn_wout_decay * n.w_out) # w_out decay
    end
    # Δw = Δw / kfn.kfnParams[:linear_neuron_number] # average
    # NOTE(review): if `outputNeuronsArray` is empty, `Δw` stays `nothing` and
    # the indexing below raises — confirm the array is always non-empty here.
    # Also assumes `Δw` has one element per compute neuron — TODO confirm.

    input_neuron_number = kfn.kfnParams[:input_neuron_number] # skip input neuron
    for i = 1:kfn.kfnParams[:compute_neuron_number]
        # compute neurons follow the input neurons in `neuronsArray`
        n = kfn.neuronsArray[input_neuron_number+i]
        n.Bn = n.Bn + Δw[i]
        n.Bn = n.Bn - (n.Bn_wout_decay * n.Bn) # w_out decay
    end
end
""" Regulates membrane potential to stay under v_th, output is weight change
|
||
"""
|
||
function cal_v_reg!(n::lif_neuron)
|
||
# retified linear function
|
||
component_a1 = n.v_t1 - n.v_th < 0 ? 0 : (n.v_t1 - n.v_th)^2
|
||
component_a2 = -n.v_t1 - n.v_th < 0 ? 0 : (-n.v_t1 - n.v_th)^2
|
||
n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2
|
||
|
||
component_b = n.v_t1 - n.v_th < 0 ? 0 : n.v_t1 - n.v_th
|
||
#FIXME: not sure the following line is correct
|
||
n.reg_voltage_b = n.reg_voltage_b + (component_b * n.epsilonRec)
|
||
end
|
||
|
||
"""
    cal_v_reg!(n::alif_neuron)

Same voltage regularisation as the LIF method, but measured against the
ADAPTIVE threshold `av_th`, with the adaptation trace `epsilonRecA`
subtracted from the membrane trace.
"""
function cal_v_reg!(n::alif_neuron)
    # rectified-linear penalties: zero until the voltage crosses ±av_th
    over_pos = n.v_t1 - n.av_th
    over_neg = -n.v_t1 - n.av_th
    pen_pos = over_pos < 0 ? 0 : over_pos^2
    pen_neg = over_neg < 0 ? 0 : over_neg^2
    n.reg_voltage_a = n.reg_voltage_a + pen_pos + pen_neg

    # rectified overshoot weighted by the combined eligibility trace
    slope = over_pos < 0 ? 0 : over_pos
    #FIXME: not sure the following line is correct
    n.reg_voltage_b = n.reg_voltage_b + (slope * (n.epsilonRec - n.epsilonRecA))
end
"""
    voltage_error!(n::compute_neuron)

Store and return the voltage-regularisation error: half of the accumulated
rectified penalty `reg_voltage_a`.
"""
function voltage_error!(n::compute_neuron)
    err = 0.5 * n.reg_voltage_a
    n.reg_voltage_error = err
    return n.reg_voltage_error
end
"""
    voltage_regulator!(n::compute_neuron)

Return the voltage-regularisation weight change: optimiser step size ×
regularisation coefficient `c_reg_v` × accumulated `reg_voltage_b`.
"""
function voltage_regulator!(n::compute_neuron) # running average
    return n.optimiser.eta * n.c_reg_v * n.reg_voltage_b
end
"""
    firingRateError(kfn::knowledgeFn)

Return the firing-rate loss `0.5 * Σ firingDiff²` over all non-input neurons
of the network.
"""
function firingRateError(kfn::knowledgeFn)
    start_id = kfn.kfnParams[:input_neuron_number] + 1 # input neurons carry no firingDiff
    # generator + @view: avoids materialising an untyped intermediate array
    # and a copy of the neuron slice (original built `[... for n in slice]`)
    return 0.5 * sum(abs2(n.firingDiff) for n in @view kfn.neuronsArray[start_id:end])
end
"""
    firing_rate_regulator!(n::compute_neuron)

Return a weight change `Δw` that pushes the neuron's firing rate toward
`firingRateTarget`. One-sided: only rates ABOVE target produce a non-zero
update (below target, a zero update of the same shape is returned).
"""
function firing_rate_regulator!(n::compute_neuron)
    # n.firingRate NOT running average (average over learning batch)
    # NOTE(review): `n.eRec` — every other method in this file uses
    # `epsilonRec`; confirm `eRec` is a real field and not a typo.
    Δw = n.optimiser.eta * n.c_reg *
        (n.firingRate - n.firingRateTarget) * n.eRec
    # one-sided clamp; `Δw * 0.0` keeps the container shape of Δw
    Δw = n.firingRate > n.firingRateTarget ? Δw : Δw * 0.0
    return Δw
end
# Firing rate over the episode: spikes per timestep × 1000
# (presumably `timeStep` counts milliseconds, giving Hz — TODO confirm).
firing_rate!(n::compute_neuron) = n.firingRate = (n.firingCounter / n.timeStep) * 1000
# Signed deviation of the measured firing rate from the target rate.
firing_diff!(n::compute_neuron) = n.firingDiff = n.firingRate - n.firingRateTarget
"""
    adjust_internal_learning_rate!(n::compute_neuron)

Multiplicative learning-rate schedule: shrink by ×0.99 when the most recent
error delta improved (went negative), otherwise grow slowly by ×1.005.
"""
function adjust_internal_learning_rate!(n::compute_neuron)
    if n.error_diff[end] < 0.0
        n.internal_learning_rate = n.internal_learning_rate * 0.99
    else
        n.internal_learning_rate = n.internal_learning_rate * 1.005
    end
    return n.internal_learning_rate
end
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
||
sigmoid operating range which centred at 0 and range is -37 to 37.
|
||
# Example
|
||
synaptic strength range is 0 to 10
|
||
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
||
the return value is shifted back to original scale
|
||
"""
|
||
function synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)
|
||
currentStrength += bias
|
||
if currentStrength > 0
|
||
Δstrength = (1.0 - sigmoid(currentStrength))
|
||
else
|
||
Δstrength = sigmoid(currentStrength)
|
||
end
|
||
|
||
if updown == "up"
|
||
updatedStrength = currentStrength + Δstrength
|
||
else
|
||
updatedStrength = currentStrength - Δstrength
|
||
end
|
||
updatedStrength -= bias
|
||
return updatedStrength
|
||
end
|
||
|
||
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
||
below lowerlimit.
|
||
"""
|
||
function synapticConnStrength!(n::Union{compute_neuron, output_neuron})
|
||
for (i, connStrength) in enumerate(n.synapticStrength)
|
||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||
updown = n.epsilonRec[i] == 0.0 ? "down" : "up"
|
||
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||
n.wRec[i] = 0.0
|
||
end
|
||
end
|
||
end
|
||
|
||
# Input neurons have no recurrent synapses; nothing to update.
synapticConnStrength!(::input_neuron) = nothing
"""
    neuroplasticity!(n::compute_neuron, firedNeurons::Vector)

Re-wire synapses whose recurrent weight has decayed to zero: with a
learning-rate-dependent probability, point each dead synapse at a freshly
sampled neuron drawn from `firedNeurons`, seeding the new weight at 0.01.
"""
function neuroplasticity!(n::compute_neuron, firedNeurons::Vector)
    # indices of dead (zero-weight) synapses eligible for re-wiring
    zero_weight_index = findall(iszero.(n.wRec))
    isempty(zero_weight_index) && return

    # Sample candidates from neurons that actually fired — a silent neuron
    # carries no information, so connecting to one is pointless. Exclude this
    # neuron itself and anything it already subscribes to.
    subscribe_options = filter(x -> x ∉ [n.id], firedNeurons)
    filter!(x -> x ∉ n.subscriptionList, subscribe_options)
    shuffle!(subscribe_options)

    # re-wiring probability scales inversely with the learning rate (0.1%-10%)
    new_connection_percent = 10 - ((n.optimiser.eta / 0.0001) / 10)
    percentage = [new_connection_percent, 100.0 - new_connection_percent] / 100.0
    for i in zero_weight_index
        # BUG FIX: `pop!` on an exhausted candidate pool threw an ArgumentError
        isempty(subscribe_options) && break
        # BUG FIX: was `Utils.random_choices`, but this file imports
        # `GeneralUtils` (cf. the `GeneralUtils.limitvalue` call above) — `Utils`
        # is an undefined module name at runtime.
        if GeneralUtils.random_choices([true, false], percentage)
            n.subscriptionList[i] = pop!(subscribe_options)
            # new connection should not send large signal otherwise it would
            # throw the RSNN off path; let the weight grow via the optimiser
            n.wRec[i] = 0.01
        end
    end
end
end # end module |