Files
Ironpen/src/snn_utils.jl
2023-05-22 20:10:10 +07:00

468 lines
18 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
module snn_utils
using Flux.Optimise: apply!
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
using GeneralUtils
using ..types
#------------------------------------------------------------------------------------------------100
# Advance one simulation step: the precomputed next-step values (`*_t1`) become the
# current-step values (`*_t`). Passthrough neurons only carry `z`.
timestep_forward!(x::passthroughNeuron) = (x.z_t = x.z_t1)

# Compute and output neurons additionally carry `v`; returns the new `v_t`.
function timestep_forward!(x::Union{computeNeuron, outputNeuron})
    x.z_t = x.z_t1
    return (x.v_t = x.v_t1)
end
"""
    no_negative!(x)

Clamp a negative `x` to zero of the same type as `x`. Non-negative values, `-0.0` and `NaN`
are returned unchanged. Despite the `!`, nothing is mutated.
"""
no_negative!(x) = x < zero(x) ? zero(x) : x  # zero(x) keeps the return type stable
# Coefficient of variation, in percent, of the per-element means of `x`:
# (std(means) / mean(means)) * 100, using Statistics' default (corrected) std.
# NOTE(review): defining `precision` here without importing it shadows `Base.precision`
# inside this module.
function precision(x::Array{<:Array})
    means = mean.(x)
    return (std(means) / mean(means)) * 100
end
# reset functions for LIF/ALIF neuron
# Per-episode state resets for compute (LIF/ALIF) neurons; several methods also cover output
# neurons. Each reset assigns the field and returns the assigned value. Numeric/array fields
# are cleared by multiplying with 0.0, which keeps an array's shape
# (NOTE: NaN/Inf entries become NaN rather than 0).
reset_last_firing_time!(n::computeNeuron) = (n.lastFiringTime = 0.0)
reset_refractory_state_active!(n::computeNeuron) = (n.refractory_state_active = false)
reset_v_t!(n::neuron) = (n.v_t = n.vRest)  # v_t back to this neuron's vRest
reset_z_t!(n::computeNeuron) = (n.z_t = false)
reset_epsilonRec!(n::computeNeuron) = (n.epsilonRec *= 0.0)
reset_epsilonRec!(n::outputNeuron) = (n.epsilonRec *= 0.0)
reset_epsilonRecA!(n::alifNeuron) = (n.epsilonRecA *= 0.0)
function reset_epsilon_in!(n::computeNeuron)
    # optional field: a `nothing` value is left as is
    n.epsilon_in === nothing && return nothing
    return (n.epsilon_in *= 0.0)
end
reset_error!(n::Union{computeNeuron, outputNeuron}) = (n.error = nothing)
function reset_w_in_change!(n::computeNeuron)
    # optional field: a `nothing` value is left as is
    n.w_in_change === nothing && return nothing
    return (n.w_in_change *= 0.0)
end
reset_wRecChange!(n::Union{computeNeuron, outputNeuron}) = (n.wRecChange *= 0.0)
reset_a!(n::alifNeuron) = (n.a *= 0.0)
reset_reg_voltage_a!(n::computeNeuron) = (n.reg_voltage_a *= 0.0)
reset_reg_voltage_b!(n::computeNeuron) = (n.reg_voltage_b *= 0.0)
reset_reg_voltage_error!(n::computeNeuron) = (n.reg_voltage_error *= 0.0)
reset_firing_counter!(n::Union{computeNeuron, outputNeuron}) = (n.firingCounter *= 0.0)
reset_firing_diff!(n::Union{computeNeuron, outputNeuron}) = (n.firingDiff *= 0.0)
reset_refractoryCounter!(n::Union{computeNeuron, outputNeuron}) = (n.refractoryCounter *= 0.0)
reset_z_i_t_commulative!(n::Union{computeNeuron, outputNeuron}) = (n.z_i_t_commulative *= 0.0)
# reset function for output neuron
# Learning-trace resets for linear neurons. Each clears its field by multiplying it with 0.0
# (arrays keep their shape) and returns the cleared value.
function reset_epsilon_j!(n::linearNeuron)
    n.epsilon_j *= 0.0
end
function reset_out_t!(n::linearNeuron)
    n.out_t *= 0.0
end
function reset_w_out_change!(n::linearNeuron)
    n.w_out_change *= 0.0
end
function reset_b_change!(n::linearNeuron)
    n.b_change *= 0.0
end
""" Reset a part of learning-related params that used to collect learning history during learning
session
"""
# function reset_learning_no_wchange!(n::lifNeuron)
# reset_epsilonRec!(n)
# # reset_v_t!(n)
# # reset_z_t!(n)
# # reset_reg_voltage_a!(n)
# # reset_reg_voltage_b!(n)
# # reset_reg_voltage_error!(n)
# reset_firing_counter!(n)
# reset_firing_diff!(n)
# reset_previous_error!(n)
# reset_error!(n)
# # # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
# # # it will stay in refractory state forever
# # reset_refractory_state_active!(n)
# end
# function reset_learning_no_wchange!(n::Union{alifNeuron, elif_neuron})
# reset_epsilonRec!(n)
# reset_epsilonRecA!(n)
# reset_v_t!(n)
# reset_z_t!(n)
# # reset_a!(n)
# reset_reg_voltage_a!(n)
# reset_reg_voltage_b!(n)
# reset_reg_voltage_error!(n)
# reset_firing_counter!(n)
# reset_firing_diff!(n)
# reset_previous_error!(n)
# reset_error!(n)
# # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
# # it will stay in refractory state forever
# reset_refractory_state_active!(n)
# end
# function reset_learning_no_wchange!(n::linearNeuron)
# reset_epsilon_j!(n)
# reset_out_t!(n)
# reset_error!(n)
# end
""" Reset all learning-related params at the END of learning session
"""
function resetLearningParams!(n::lifNeuron)
reset_epsilonRec!(n)
reset_wRecChange!(n)
# reset_v_t!(n)
# reset_z_t!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever
# reset_refractoryCounter!(n)
reset_z_i_t_commulative!(n)
end
function resetLearningParams!(n::alifNeuron)
reset_epsilonRec!(n)
reset_epsilonRecA!(n)
reset_wRecChange!(n)
# reset_v_t!(n)
# reset_z_t!(n)
# reset_a!(n)
reset_firing_counter!(n)
reset_firing_diff!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever
# reset_refractoryCounter!(n)
reset_z_i_t_commulative!(n)
end
# function reset_learning_no_wchange!(n::passthroughNeuron)
# end
function resetLearningParams!(n::passthroughNeuron)
# skip
end
function resetLearningParams!(n::linearNeuron)
reset_epsilonRec!(n)
reset_wRecChange!(n)
# reset_v_t!(n)
reset_firing_counter!(n)
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
# refractory state, it will stay in refractory state forever
# reset_refractoryCounter!(n)
reset_z_i_t_commulative!(n)
end
#------------------------------------------------------------------------------------------------100
"""
    store_knowledgefn_error!(kfn::knowledgeFn)

Record `kfn.knowledgeFn_error` in `kfn.recent_knowledgeFn_error`. The history is a vector
of per-session buckets (one inner vector per learning session). Its handling depends on
`kfn.learningStage`:

- `"start_learning"`: open a new bucket, seeded with the current error unless it is
  `nothing`. The history itself is created on first use.
- `"during_learning"`: append the current error to the newest bucket, unless it is
  `nothing`.
- `"end_learning"`: append the current error to the newest bucket, if a history exists.

Only the 3 most recent buckets are kept. Any other stage throws an `ErrorException`.
"""
function store_knowledgefn_error!(kfn::knowledgeFn)
    stage = kfn.learningStage
    current_error = kfn.knowledgeFn_error
    if stage == "start_learning"
        bucket = current_error === nothing ? [] : [current_error]
        if kfn.recent_knowledgeFn_error === nothing
            kfn.recent_knowledgeFn_error = [bucket]
        else
            push!(kfn.recent_knowledgeFn_error, bucket)
        end
    elseif stage == "during_learning"
        if current_error !== nothing
            push!(kfn.recent_knowledgeFn_error[end], current_error)
        end
    elseif stage == "end_learning"
        # NOTE(review): unlike "during_learning", a `nothing` error is pushed here as-is;
        # confirm this is intended.
        if kfn.recent_knowledgeFn_error !== nothing
            push!(kfn.recent_knowledgeFn_error[end], current_error)
        end
    else
        error("case does not defined yet")
    end
    # Keep only the 3 most recent sessions. The history may still be `nothing` here (e.g. the
    # skip paths above before any "start_learning"); previously this crashed on length(nothing).
    history = kfn.recent_knowledgeFn_error
    if history !== nothing && length(history) > 3
        deleteat!(history, 1)
    end
end
"""
    update_Bn!(kfn::knowledgeFn)

Sum `w_out_change` over all output neurons, and decay each output neuron's `w_out` by its
`Bn_wout_decay` (`w_out -= Bn_wout_decay * w_out`). Then, for the `i`-th compute neuron
(stored right after the input neurons in `kfn.neuronsArray`), add the `i`-th entry of that
sum to its `Bn` and decay `Bn` the same way using that neuron's `Bn_wout_decay`.

NOTE(review): this needs at least one output neuron. With none, the sum stays `nothing` and
indexing it throws (when there are compute neurons).
"""
function update_Bn!(kfn::knowledgeFn)
    total_change = nothing
    for outn in kfn.outputNeuronsArray
        total_change = isnothing(total_change) ? outn.w_out_change :
                       total_change + outn.w_out_change
        outn.w_out -= outn.Bn_wout_decay * outn.w_out
    end
    # total_change is a plain sum over output neurons (no averaging).
    offset = kfn.kfnParams[:input_neuron_number]  # input neurons come first; skip them
    for i in 1:kfn.kfnParams[:compute_neuron_number]
        cn = kfn.neuronsArray[offset + i]
        cn.Bn += total_change[i]
        cn.Bn -= cn.Bn_wout_decay * cn.Bn
    end
    return nothing
end
""" Regulates membrane potential to stay under v_th, output is weight change
"""
function cal_v_reg!(n::lifNeuron)
# retified linear function
component_a1 = n.v_t1 - n.v_th < 0 ? 0 : (n.v_t1 - n.v_th)^2
component_a2 = -n.v_t1 - n.v_th < 0 ? 0 : (-n.v_t1 - n.v_th)^2
n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2
component_b = n.v_t1 - n.v_th < 0 ? 0 : n.v_t1 - n.v_th
#FIXME: not sure the following line is correct
n.reg_voltage_b = n.reg_voltage_b + (component_b * n.epsilonRec)
end
function cal_v_reg!(n::alifNeuron)
# retified linear function
component_a1 = n.v_t1 - n.av_th < 0 ? 0 : (n.v_t1 - n.av_th)^2
component_a2 = -n.v_t1 - n.av_th < 0 ? 0 : (-n.v_t1 - n.av_th)^2
n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2
component_b = n.v_t1 - n.av_th < 0 ? 0 : n.v_t1 - n.av_th
#FIXME: not sure the following line is correct
n.reg_voltage_b = n.reg_voltage_b + (component_b * (n.epsilonRec - n.epsilonRecA))
end
"""
    voltage_error!(n::computeNeuron)

Store half of the accumulated `n.reg_voltage_a` in `n.reg_voltage_error` and return it.
"""
function voltage_error!(n::computeNeuron)
    half_a = 0.5 * n.reg_voltage_a
    n.reg_voltage_error = half_a
    return n.reg_voltage_error
end

"""
    voltage_regulator!(n::computeNeuron)

Return the voltage-regularisation weight change `eta * c_reg_v * reg_voltage_b`. It only
reads `n` and does not mutate it.
"""
voltage_regulator!(n::computeNeuron) = n.optimiser.eta * n.c_reg_v * n.reg_voltage_b
"""
    firingRateError(kfn::knowledgeFn)

Return half the sum of squared `firingDiff` over every neuron in `kfn.neuronsArray` after
the first `kfn.kfnParams[:input_neuron_number]` (i.e. all non-input neurons).
"""
function firingRateError(kfn::knowledgeFn)
    first_non_input = kfn.kfnParams[:input_neuron_number] + 1
    squared_diffs = [neuron.firingDiff^2 for neuron in kfn.neuronsArray[first_non_input:end]]
    return 0.5 * sum(squared_diffs)
end
"""
    firing_rate_regulator!(n::computeNeuron)

Return the firing-rate regularisation weight change
`eta * c_reg * (firingRate - firingRateTarget) * eRec`. It is only active while the neuron
fires above its target; otherwise the same quantity multiplied by 0.0 is returned.
`n.firingRate` is an average over the learning batch, not a running average.
"""
function firing_rate_regulator!(n::computeNeuron)
    rate_gap = n.firingRate - n.firingRateTarget
    Δw = n.optimiser.eta * n.c_reg * rate_gap * n.eRec
    if n.firingRate > n.firingRateTarget
        return Δw
    else
        return Δw * 0.0
    end
end

# Store and return firingCounter / timeStep scaled by 1000
# (presumably spikes per second if timeStep counts milliseconds — unconfirmed).
firing_rate!(n::computeNeuron) = (n.firingRate = 1000 * (n.firingCounter / n.timeStep))

# Store and return how far the firing rate is above its target (negative when below).
firing_diff!(n::computeNeuron) = (n.firingDiff = n.firingRate - n.firingRateTarget)
"""
    adjust_internal_learning_rate!(n::computeNeuron)

Scale `n.internal_learning_rate` multiplicatively from the latest entry of `n.error_diff`:
multiply by 0.99 when that entry is negative, otherwise by 1.005. Returns the new rate.
"""
function adjust_internal_learning_rate!(n::computeNeuron)
    factor = n.error_diff[end] < 0.0 ? 0.99 : 1.005
    n.internal_learning_rate = n.internal_learning_rate * factor
end
""" compute synaptic connection strength. bias will shift currentStrength to fit into
sigmoid operating range which centred at 0 and range is -37 to 37.
# Example
synaptic strength range is 0 to 10
one may use bias = -5 to transform synaptic strength into range -5 to 5
the return value is shifted back to original scale
"""
function synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)::Float64
currentStrength += bias
if currentStrength > 0
Δstrength = (1.0 - sigmoid(currentStrength))
else
Δstrength = sigmoid(currentStrength)
end
if updown == "up"
updatedStrength = currentStrength + Δstrength
else
updatedStrength = currentStrength - Δstrength
end
updatedStrength -= bias
return updatedStrength
end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
below lowerlimit.
"""
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
for (i, connStrength) in enumerate(n.synapticStrength)
# check whether connStrength increase or decrease based on usage from n.epsilonRec
""" use n.wRecChange instead of the best choise, epsilonRec, here because ΔwRecChange
calculation in learn!() will reset epsilonRec to zeroes vector in case where
output neuron fires and trigger learn!() just before this synapticConnStrength
calculation.
Since n.wRecChange indicates whether a synaptic connection were used or not, it is
ok to use. n.wRecChange also span across a training sample without resetting.
"""
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
updatedConnStrength = synapticConnStrength(connStrength, updown)
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end
end
function synapticConnStrength!(n::inputNeuron) end
""" normalize a part of a vector centering at a vector's maximum value along with nearby value
within its radius. radius must be odd number.
v1 will be normalized based on v2's peak
"""
function normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)
peak = findall(isequal.(abs.(v2), maximum(abs.(v2))))[1]
upindex = peak - radius
upindex = upindex < 1 ? 1 : upindex
downindex = peak + radius
downindex = downindex > length(v1) ? length(v1) : downindex
subvector = view(v1, upindex:downindex)
normalize!(subvector, 1)
end
""" rewire of neuron synaptic connection that has 0 weight. With connection's excitatory and
inhabitory ratio constraint.
"""
# function neuroplasticity!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
# nExcitatory::Vector, nInhabitory::Vector, excitatoryPercent::Integer)
# # if there is 0-weight then replace it with new connection
# zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
# desiredEx = Int(floor((excitatoryPercent / 100) * length(n.subscriptionList)))
# desiredIn = length(n.subscriptionList) - desiredEx
# wRecSign = sign.(n.wRec)
# inConn = sum(isequal.(wRecSign, -1))
# # random new synaptic connection
# inConnToAdd = desiredIn - inConn
# if inConnToAdd <= 0
# # skip all new Conn will be excitatory type
# else
# newConnVecSign = ones(length(zeroWeightConnIndex))
# newConnVecSign = view(newConnVecSign, 1:inConnToAdd) * -1
# end
# # new synaptic connection must sample fron neuron that fires
# inPool = nInhabitory ∩ firedNeurons
# filter!(x -> x ∉ [n.id], inPool) # exclude this neuron id from the list
# filter!(x -> x ∉ n.subscriptionList, inPool) # exclude this neuron's subscriptionList from the list
# exPool = nExcitatory ∩ firedNeurons
# filter!(x -> x ∉ [n.id], exPool) # exclude this neuron id from the list
# filter!(x -> x ∉ n.subscriptionList, exPool) # exclude this neuron's subscriptionList from the list
# w = [rand(0.01:0.01:0.2, length(zeroWeightConnIndex))] .* newConnVecSign
# synapticStrength = [rand(-5:0.01:-4, length(zeroWeightConnIndex))]
# # add new synaptic connection to neuron
# for (i, connIndex) in enumerate(zeroWeightConnIndex)
# n.subscriptionList[connIndex] = newConnVecSign[i] < 0 ? pop!(inPool) : pop!(exPool)
# n.wRec[connIndex] = w[i]
# n.synapticStrength[connIndex] = synapticStrength[i]
# end
# end
""" rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory and
inhabitory ratio constraint.
"""
function neuroplasticity!(n::computeNeuron, firedNeurons::Vector,
nExInTypeList::Vector)
# if there is 0-weight then replace it with new connection
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
# new synaptic connection must sample fron neuron that fires
nFiredPool = filter(x -> x [n.id], firedNeurons) # exclude this neuron id from the id list
filter!(x -> x n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
filter!(x -> x [n.id], nNonFiredPool) # exclude this neuron id from the id list
filter!(x -> x n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex))
synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex))
shuffle!(nFiredPool)
shuffle!(nNonFiredPool)
# add new synaptic connection to neuron
for (i, connIndex) in enumerate(zeroWeightConnIndex)
if length(nFiredPool) != 0
newConn = popfirst!(nFiredPool)
else
newConn = popfirst!(nNonFiredPool)
end
""" conn that is being replaced has to go into nNonFiredPool so nNonFiredPool isn't empty
"""
push!(nNonFiredPool, n.subscriptionList[connIndex])
n.subscriptionList[connIndex] = newConn
n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
n.synapticStrength[connIndex] = synapticStrength[i]
end
end
function neuroplasticity!(n::outputNeuron, firedNeurons::Vector,
nExInTypeList::Vector, totalInputNeuron::Integer)
# if there is 0-weight then replace it with new connection
zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
# new synaptic connection must sample fron neuron that fires
nFiredPool = filter(x -> x [n.id], firedNeurons) # exclude this neuron id from the id list
filter!(x -> x n.subscriptionList, nFiredPool) # exclude this neuron's subscriptionList from the list
filter!(x -> x [1:totalInputNeuron...], nFiredPool) # exclude input neuron
nNonFiredPool = setdiff!([1:length(nExInTypeList)...], nFiredPool)
filter!(x -> x [n.id], nNonFiredPool) # exclude this neuron id from the id list
filter!(x -> x n.subscriptionList, nNonFiredPool) # exclude this neuron's subscriptionList from the list
filter!(x -> x [1:totalInputNeuron...], nNonFiredPool) # exclude input neuron
w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex))
synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex))
shuffle!(nFiredPool)
shuffle!(nNonFiredPool)
# add new synaptic connection to neuron
for (i, connIndex) in enumerate(zeroWeightConnIndex)
if length(nFiredPool) != 0
newConn = popfirst!(nFiredPool)
else
newConn = popfirst!(nNonFiredPool)
end
""" conn that is being replaced has to go into nNonFiredPool so nNonFiredPool isn't empty
"""
push!(nNonFiredPool, n.subscriptionList[connIndex])
n.subscriptionList[connIndex] = newConn
n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
n.synapticStrength[connIndex] = synapticStrength[i]
end
end
end # end module