module snn_utils

using Flux.Optimise: apply!

export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
    precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
    reset_z_t!, resetLearningParams!, reset_learning_history_params!, reset_epsilonRec!,
    reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, firing_rate_error!,
    firing_rate_regulator!, update_Bn!, cal_firing_reg!, neuroplasticity!, shakeup!,
    reset_learning_no_wchange!, adjust_internal_learning_rate!, gradient_withloss

using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
using GeneralUtils
using ..types

#------------------------------------------------------------------------------------------------100

"""
    timestep_forward!(x)

Advance a neuron by one time step by copying its "next step" state into its current
state: `z_t1` into `z_t`, and for compute/output neurons also `v_t1` into `v_t`.
"""
function timestep_forward!(x::passthroughNeuron)
    x.z_t = x.z_t1
end

function timestep_forward!(x::Union{computeNeuron, outputNeuron})
    x.z_t = x.z_t1
    x.v_t = x.v_t1
end

"""
    no_negative!(x)

Return `0.0` when `x` is below zero, otherwise return `x` unchanged. Despite the `!`
in the name, nothing is mutated.
"""
no_negative!(x) = x < 0.0 ? 0.0 : x

"""
    precision(x::Array{<:Array})

Return the standard deviation of the per-element means of `x` divided by the mean of
those means, multiplied by 100.
"""
function precision(x::Array{<:Array})
    means = mean.(x)
    return (std(means) / mean(means)) * 100
end

# reset functions for LIF/ALIF neuron
#
# Most of these overwrite the field with `field * 0.0`, i.e. they rebind the field to a
# new zero-scaled value instead of zeroing it in place. Under IEEE arithmetic a NaN or
# Inf entry stays NaN after multiplication by 0.0.
reset_last_firing_time!(n::computeNeuron) = n.lastFiringTime = 0.0
reset_refractory_state_active!(n::computeNeuron) = n.refractory_state_active = false
reset_v_t!(n::neuron) = n.v_t = n.vRest
reset_z_t!(n::computeNeuron) = n.z_t = false
reset_epsilonRec!(n::computeNeuron) = n.epsilonRec = n.epsilonRec * 0.0
reset_epsilonRec!(n::outputNeuron) = n.epsilonRec = n.epsilonRec * 0.0
reset_epsilonRecA!(n::alifNeuron) = n.epsilonRecA = n.epsilonRecA * 0.0
# `epsilon_in` and `w_in_change` may be `nothing`; in that case they are left as `nothing`
reset_epsilon_in!(n::computeNeuron) =
    n.epsilon_in = isnothing(n.epsilon_in) ? nothing : n.epsilon_in * 0.0
reset_error!(n::Union{computeNeuron, outputNeuron}) = n.error = nothing
reset_w_in_change!(n::computeNeuron) =
    n.w_in_change = isnothing(n.w_in_change) ? nothing : n.w_in_change * 0.0
reset_wRecChange!(n::Union{computeNeuron, outputNeuron}) = n.wRecChange = n.wRecChange * 0.0
reset_a!(n::alifNeuron) = n.a = n.a * 0.0
reset_reg_voltage_a!(n::computeNeuron) = n.reg_voltage_a = n.reg_voltage_a * 0.0
reset_reg_voltage_b!(n::computeNeuron) = n.reg_voltage_b = n.reg_voltage_b * 0.0
reset_reg_voltage_error!(n::computeNeuron) = n.reg_voltage_error = n.reg_voltage_error * 0.0
reset_firing_counter!(n::Union{computeNeuron, outputNeuron}) =
    n.firingCounter = n.firingCounter * 0.0
reset_firing_diff!(n::Union{computeNeuron, outputNeuron}) = n.firingDiff = n.firingDiff * 0.0
reset_refractoryCounter!(n::Union{computeNeuron, outputNeuron}) =
    n.refractoryCounter = n.refractoryCounter * 0.0
reset_z_i_t_commulative!(n::Union{computeNeuron, outputNeuron}) =
    n.z_i_t_commulative = n.z_i_t_commulative * 0.0

# reset function for output neuron
reset_epsilon_j!(n::linearNeuron) = n.epsilon_j = n.epsilon_j * 0.0
reset_out_t!(n::linearNeuron) = n.out_t = n.out_t * 0.0
reset_w_out_change!(n::linearNeuron) = n.w_out_change = n.w_out_change * 0.0
reset_b_change!(n::linearNeuron) = n.b_change = n.b_change * 0.0

# Reset a part of the learning-related params that are used to collect learning history
# during a learning session.
# (Disabled implementation kept for reference; `reset_learning_no_wchange!` is still
# listed in the export statement but has no active method in this file.)
#
# function reset_learning_no_wchange!(n::lifNeuron)
#     reset_epsilonRec!(n)
#     # reset_v_t!(n)
#     # reset_z_t!(n)
#     # reset_reg_voltage_a!(n)
#     # reset_reg_voltage_b!(n)
#     # reset_reg_voltage_error!(n)
#     reset_firing_counter!(n)
#     reset_firing_diff!(n)
#     reset_previous_error!(n)
#     reset_error!(n)
#
#     # # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
#     # # it will stay in refractory state forever
#     # reset_refractory_state_active!(n)
# end

# function reset_learning_no_wchange!(n::Union{alifNeuron, elif_neuron})
#     reset_epsilonRec!(n)
#     reset_epsilonRecA!(n)
#     reset_v_t!(n)
#     reset_z_t!(n)
#     # reset_a!(n)
#     reset_reg_voltage_a!(n)
#     reset_reg_voltage_b!(n)
#     reset_reg_voltage_error!(n)
#     reset_firing_counter!(n)
#     reset_firing_diff!(n)
#     reset_previous_error!(n)
#     reset_error!(n)
#
#     # reset refractory state at the end of episode. Otherwise once neuron goes into refractory state,
#     # it will stay in refractory state forever
#     reset_refractory_state_active!(n)
# end

# function reset_learning_no_wchange!(n::linearNeuron)
#     reset_epsilon_j!(n)
#     reset_out_t!(n)
#     reset_error!(n)
# end

"""
    resetLearningParams!(n)

Reset all learning-related params at the END of a learning session.

For a `lifNeuron` this clears `epsilonRec`, `wRecChange`, `firingCounter`, `firingDiff`
and `z_i_t_commulative`. The `alifNeuron` method additionally clears `epsilonRecA`, the
`linearNeuron` method skips `firingDiff`, and the `passthroughNeuron` method does nothing.
"""
function resetLearningParams!(n::lifNeuron)
    reset_epsilonRec!(n)
    reset_wRecChange!(n)
    # reset_v_t!(n)
    # reset_z_t!(n)
    reset_firing_counter!(n)
    reset_firing_diff!(n)

    # reset refractory state at the start/end of episode. Otherwise once neuron goes into
    # refractory state, it will stay in refractory state forever
    # NOTE(review): the reset call below is disabled in this file — confirm that is intended.
    # reset_refractoryCounter!(n)
    reset_z_i_t_commulative!(n)
end

function resetLearningParams!(n::alifNeuron)
    reset_epsilonRec!(n)
    reset_epsilonRecA!(n)
    reset_wRecChange!(n)
    # reset_v_t!(n)
    # reset_z_t!(n)
    # reset_a!(n)
    reset_firing_counter!(n)
    reset_firing_diff!(n)

    # reset refractory state at the start/end of episode. Otherwise once neuron goes into
    # refractory state, it will stay in refractory state forever
    # NOTE(review): the reset call below is disabled in this file — confirm that is intended.
    # reset_refractoryCounter!(n)
    reset_z_i_t_commulative!(n)
end

# function reset_learning_no_wchange!(n::passthroughNeuron)
# end

function resetLearningParams!(n::passthroughNeuron)
    # skip
end

function resetLearningParams!(n::linearNeuron)
    reset_epsilonRec!(n)
    reset_wRecChange!(n)
    # reset_v_t!(n)
    reset_firing_counter!(n)

    # reset refractory state at the start/end of episode. Otherwise once neuron goes into
    # refractory state, it will stay in refractory state forever
    # NOTE(review): the reset call below is disabled in this file — confirm that is intended.
    # reset_refractoryCounter!(n)
    reset_z_i_t_commulative!(n)
end

#------------------------------------------------------------------------------------------------100

"""
    store_knowledgefn_error!(kfn::knowledgeFn)

Record `kfn.knowledgeFn_error` into `kfn.recent_knowledgeFn_error`, a vector holding one
inner vector per learning session, depending on `kfn.learningStage`:

- `"start_learning"`: open a new inner vector (creating the outer vector if it is still
  `nothing`), seeded with the current error when that error is not `nothing`.
- `"during_learning"`: append the current error to the latest inner vector unless the
  error is `nothing`.
- `"end_learning"`: append the current error to the latest inner vector unless no history
  exists yet.

Afterwards the history is trimmed by dropping its oldest entry when it holds more than 3.

# Throws
- `ErrorException` for any other `learningStage` value.
"""
function store_knowledgefn_error!(kfn::knowledgeFn)
    # condition to adjust neuron in KFN plane in addition to weight adjustment inside each neuron
    if kfn.learningStage == "start_learning"
        if kfn.recent_knowledgeFn_error === nothing && kfn.knowledgeFn_error === nothing
            kfn.recent_knowledgeFn_error = [[]]
        elseif kfn.recent_knowledgeFn_error === nothing
            kfn.recent_knowledgeFn_error = [[kfn.knowledgeFn_error]]
        elseif kfn.recent_knowledgeFn_error !== nothing && kfn.knowledgeFn_error === nothing
            push!(kfn.recent_knowledgeFn_error, [])
        else
            push!(kfn.recent_knowledgeFn_error, [kfn.knowledgeFn_error])
        end
    elseif kfn.learningStage == "during_learning"
        if kfn.knowledgeFn_error === nothing
            #skip
        else
            push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
        end
    elseif kfn.learningStage == "end_learning"
        if kfn.recent_knowledgeFn_error === nothing
            #skip
        else
            push!(kfn.recent_knowledgeFn_error[end], kfn.knowledgeFn_error)
        end
    else
        error("case does not defined yet")
    end

    # The "end_learning" branch can leave the history as `nothing`; `length(nothing)` has
    # no method, so only trim when a history actually exists.
    history = kfn.recent_knowledgeFn_error
    if history !== nothing && length(history) > 3
        deleteat!(history, 1)
    end
end

"""
    update_Bn!(kfn::knowledgeFn)

Sum `w_out_change` over every neuron in `kfn.outputNeuronsArray`, decaying each output
neuron's `w_out` by `Bn_wout_decay` along the way. Then, for each of the
`kfn.kfnParams[:compute_neuron_number]` neurons that follow the input neurons in
`kfn.neuronsArray`, add the matching entry of that sum to `Bn` and apply the same decay
to `Bn`.
"""
function update_Bn!(kfn::knowledgeFn)
    Δw = nothing
    for n in kfn.outputNeuronsArray
        Δw = Δw === nothing ? n.w_out_change : Δw + n.w_out_change
        n.w_out = n.w_out - (n.Bn_wout_decay * n.w_out) # w_out decay
    end

    # Δw = Δw / kfn.kfnParams[:linear_neuron_number] # average

    input_neuron_number = kfn.kfnParams[:input_neuron_number] # skip input neuron
    for i = 1:kfn.kfnParams[:compute_neuron_number]
        n = kfn.neuronsArray[input_neuron_number+i]
        n.Bn = n.Bn + Δw[i]
        n.Bn = n.Bn - (n.Bn_wout_decay * n.Bn) # w_out decay
    end
end

"""
    cal_v_reg!(n)

Regulates membrane potential to stay under the threshold; the output is a weight change.

Accumulates into `n.reg_voltage_a` the squared amount by which `n.v_t1` exceeds the
threshold (or `-n.v_t1` exceeds it), and into `n.reg_voltage_b` the positive excess
multiplied by the eligibility term. The threshold is `n.v_th` and the term `n.epsilonRec`
for a `lifNeuron`; they are `n.av_th` and `n.epsilonRec - n.epsilonRecA` for an
`alifNeuron`.
"""
function cal_v_reg!(n::lifNeuron)
    # rectified linear function
    over = n.v_t1 - n.v_th      # excess above +v_th
    under = -n.v_t1 - n.v_th    # excess below -v_th
    component_a1 = over < 0 ? 0 : over^2
    component_a2 = under < 0 ? 0 : under^2
    n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2

    component_b = over < 0 ? 0 : over
    #FIXME: not sure the following line is correct
    n.reg_voltage_b = n.reg_voltage_b + (component_b * n.epsilonRec)
end

function cal_v_reg!(n::alifNeuron)
    # rectified linear function
    over = n.v_t1 - n.av_th     # excess above +av_th
    under = -n.v_t1 - n.av_th   # excess below -av_th
    component_a1 = over < 0 ? 0 : over^2
    component_a2 = under < 0 ? 0 : under^2
    n.reg_voltage_a = n.reg_voltage_a + component_a1 + component_a2

    component_b = over < 0 ? 0 : over
    #FIXME: not sure the following line is correct
    n.reg_voltage_b = n.reg_voltage_b + (component_b * (n.epsilonRec - n.epsilonRecA))
end

"""
    voltage_error!(n::computeNeuron)

Store half of `n.reg_voltage_a` in `n.reg_voltage_error` and return it.
"""
function voltage_error!(n::computeNeuron)
    n.reg_voltage_error = 0.5 * n.reg_voltage_a
    return n.reg_voltage_error
end

"""
    voltage_regulator!(n::computeNeuron)

Return the voltage-regularisation weight change
`n.optimiser.eta * n.c_reg_v * n.reg_voltage_b`. Nothing is mutated.
"""
function voltage_regulator!(n::computeNeuron)
    # running average
    Δw = n.optimiser.eta * n.c_reg_v * n.reg_voltage_b
    return Δw
end

"""
    firingRateError(kfn::knowledgeFn)

Return half the sum of squared `firingDiff` over the neurons of `kfn.neuronsArray` that
come after the first `kfn.kfnParams[:input_neuron_number]` entries.
"""
function firingRateError(kfn::knowledgeFn)
    start_id = kfn.kfnParams[:input_neuron_number] + 1
    return 0.5 * sum([(n.firingDiff)^2 for n in kfn.neuronsArray[start_id:end]])
end

"""
    firing_rate_regulator!(n::computeNeuron)

Return the firing-rate regularisation weight change
`eta * c_reg * (firingRate - firingRateTarget) * eRec`, scaled to zero unless
`n.firingRate` is above `n.firingRateTarget`. Nothing is mutated.
"""
function firing_rate_regulator!(n::computeNeuron)
    # n.firingRate NOT running average (average over learning batch)
    Δw = n.optimiser.eta * n.c_reg * (n.firingRate - n.firingRateTarget) * n.eRec
    Δw = n.firingRate > n.firingRateTarget ? Δw : Δw * 0.0
    return Δw
end

# firing rate is firingCounter per timeStep, scaled by 1000
firing_rate!(n::computeNeuron) = n.firingRate = (n.firingCounter / n.timeStep) * 1000
firing_diff!(n::computeNeuron) = n.firingDiff = n.firingRate - n.firingRateTarget

"""
    adjust_internal_learning_rate!(n::computeNeuron)

Scale `n.internal_learning_rate` by 0.99 when the latest entry of `n.error_diff` is
negative, and by 1.005 otherwise.
"""
function adjust_internal_learning_rate!(n::computeNeuron)
    factor = n.error_diff[end] < 0.0 ? 0.99 : 1.005
    n.internal_learning_rate = n.internal_learning_rate * factor
end

"""
    synapticConnStrength(currentStrength::AbstractFloat, updown::String, bias::Number=0)

Compute synaptic connection strength. `bias` shifts `currentStrength` so it fits into the
sigmoid operating range, which is centred at 0 (the original note gives the usable range
as -37 to 37).

The step size is `1 - sigmoid(s)` when the shifted strength `s` is positive and
`sigmoid(s)` otherwise. It is added when `updown == "up"` and subtracted for any other
value of `updown`. The return value is shifted back to the original scale.

# Example
If synaptic strength ranges from 0 to 10, one may use `bias = -5` to transform it into
the range -5 to 5.
"""
function synapticConnStrength(currentStrength::AbstractFloat, updown::String,
                              bias::Number=0)::Float64
    currentStrength += bias
    if currentStrength > 0
        Δstrength = (1.0 - sigmoid(currentStrength))
    else
        Δstrength = sigmoid(currentStrength)
    end

    if updown == "up"
        updatedStrength = currentStrength + Δstrength
    else
        updatedStrength = currentStrength - Δstrength
    end

    updatedStrength -= bias
    return updatedStrength
end

"""
    synapticConnStrength!(n)

Compute all synaptic connection strengths of a neuron. Also sets `n.wRec[i]` to 0 when
the updated strength equals `n.synapticStrengthLimit.lowerlimit[1]`. The method for
`inputNeuron` does nothing.
"""
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
    for (i, connStrength) in enumerate(n.synapticStrength)
        # A connection goes "down" when its entry in n.z_i_t_commulative is 0 and "up"
        # otherwise.
        # NOTE(review): an older note here explained using n.wRecChange instead of
        # epsilonRec, because learn!() can reset epsilonRec just before this runs. The
        # code reads z_i_t_commulative, so that note looks obsolete — confirm.
        updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"

        # This assignment must stay active: the clamp below reads updatedConnStrength.
        updatedConnStrength = synapticConnStrength(connStrength, updown)
        updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
            n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)

        # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
        if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
            n.wRec[i] = 0.0
        end
        n.synapticStrength[i] = updatedConnStrength
    end
end

function synapticConnStrength!(n::inputNeuron) end

"""
    normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)

Normalize, in place and with the 1-norm, the part of `v1` lying within `radius` positions
of the peak of `v2`. The peak is the first index where `abs.(v2)` reaches its maximum, and
the window is clipped to `1:length(v1)`. Returns the normalized view into `v1`.

NOTE(review): the original documentation says `radius` must be an odd number; the code
does not check this.
"""
function normalizePeak!(v1::Vector, v2::Vector, radius::Integer=2)
    magnitudes = abs.(v2)
    peak = findfirst(isequal(maximum(magnitudes)), magnitudes)
    upindex = max(peak - radius, 1)
    downindex = min(peak + radius, length(v1))
    subvector = view(v1, upindex:downindex)
    return normalize!(subvector, 1)
end

# Rewire of neuron synaptic connection that has 0 weight. With connection's excitatory and
# inhibitory ratio constraint. (Disabled implementation kept for reference.)
#
# function neuroplasticity!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
#                             nExcitatory::Vector, nInhabitory::Vector, excitatoryPercent::Integer)
#     # if there is 0-weight then replace it with new connection
#     zeroWeightConnIndex = findall(iszero.(n.wRec)) # connection that has 0 weight
#     desiredEx = Int(floor((excitatoryPercent / 100) * length(n.subscriptionList)))
#     desiredIn = length(n.subscriptionList) - desiredEx
#     wRecSign = sign.(n.wRec)
#     inConn = sum(isequal.(wRecSign, -1))
#
#     # random new synaptic connection
#     inConnToAdd = desiredIn - inConn
#     if inConnToAdd <= 0
#         # skip all new Conn will be excitatory type
#     else
#         newConnVecSign = ones(length(zeroWeightConnIndex))
#         newConnVecSign = view(newConnVecSign, 1:inConnToAdd) * -1
#     end
#
#     # new synaptic connection must sample fron neuron that fires
#     inPool = nInhabitory ∩ firedNeurons
#     filter!(x -> x ∉ [n.id], inPool) # exclude this neuron id from the list
#     filter!(x -> x ∉ n.subscriptionList, inPool) # exclude this neuron's subscriptionList from the list
#     exPool = nExcitatory ∩ firedNeurons
#     filter!(x -> x ∉ [n.id], exPool) # exclude this neuron id from the list
#     filter!(x -> x ∉ n.subscriptionList, exPool) # exclude this neuron's subscriptionList from the list
#
#     w = [rand(0.01:0.01:0.2, length(zeroWeightConnIndex))] .* newConnVecSign
#     synapticStrength = [rand(-5:0.01:-4, length(zeroWeightConnIndex))]
#
#     # add new synaptic connection to neuron
#     for (i, connIndex) in enumerate(zeroWeightConnIndex)
#         n.subscriptionList[connIndex] = newConnVecSign[i] < 0 ? pop!(inPool) : pop!(exPool)
#         n.wRec[connIndex] = w[i]
#         n.synapticStrength[connIndex] = synapticStrength[i]
#     end
# end

# Shared implementation of `neuroplasticity!`: replace every zero-weight connection of `n`
# with a new presynaptic id, never choosing ids in `1:nInputNeuron`.
function _rewire_zero_weight!(n::Union{computeNeuron, outputNeuron}, firedNeurons::Vector,
                              nExInTypeList::Vector, nInputNeuron::Integer)
    # if there is 0-weight then replace it with new connection
    zeroWeightConnIndex = findall(iszero, n.wRec) # connection that has 0 weight

    # A candidate is not this neuron, is not already in this neuron's subscriptionList and
    # is not an input neuron (`1:nInputNeuron` is empty when nInputNeuron == 0).
    iscandidate(x) = x != n.id && x ∉ n.subscriptionList && x ∉ 1:nInputNeuron

    # new synaptic connection must sample from neuron that fires
    nFiredPool = filter(iscandidate, firedNeurons)
    # every other id of nExInTypeList that is still a valid candidate
    nNonFiredPool = setdiff!(collect(1:length(nExInTypeList)), nFiredPool)
    filter!(iscandidate, nNonFiredPool)

    # Random draws happen in a fixed order: weights, strengths, then the two shuffles.
    w = rand(0.01:0.01:0.2, length(zeroWeightConnIndex))
    synapticStrength = rand(-5:0.01:-4, length(zeroWeightConnIndex))
    shuffle!(nFiredPool)
    shuffle!(nNonFiredPool)

    # add new synaptic connection to neuron
    for (i, connIndex) in enumerate(zeroWeightConnIndex)
        if !isempty(nFiredPool)
            newConn = popfirst!(nFiredPool)
        elseif !isempty(nNonFiredPool)
            newConn = popfirst!(nNonFiredPool)
        else
            # No candidate left at all: leave the remaining zero-weight connections as
            # they are instead of failing on popfirst! of an empty pool.
            break
        end

        # conn that is being replaced has to go into nNonFiredPool so nNonFiredPool
        # isn't empty
        push!(nNonFiredPool, n.subscriptionList[connIndex])
        n.subscriptionList[connIndex] = newConn
        # the new weight is scaled by the chosen neuron's entry in nExInTypeList
        n.wRec[connIndex] = w[i] * nExInTypeList[newConn]
        n.synapticStrength[connIndex] = synapticStrength[i]
    end
end

"""
    neuroplasticity!(n::computeNeuron, firedNeurons, nExInTypeList)
    neuroplasticity!(n::outputNeuron, firedNeurons, nExInTypeList, totalInputNeuron)

Rewire of neuron synaptic connection that has 0 weight. Without connection's excitatory
and inhibitory ratio constraint.

Every index where `n.wRec` is zero gets a new entry in `n.subscriptionList`. Ids from
`firedNeurons` are used first, then the remaining ids in `1:length(nExInTypeList)`. The
neuron's own id and ids already in its `subscriptionList` are never chosen. The new weight
is drawn from `0.01:0.01:0.2` and multiplied by `nExInTypeList[newConn]`, and the new
synaptic strength is drawn from `-5:0.01:-4`. If no candidate id exists, the connection is
left unchanged.

The `outputNeuron` method also never connects to ids in `1:totalInputNeuron`.
"""
function neuroplasticity!(n::computeNeuron, firedNeurons::Vector, nExInTypeList::Vector)
    _rewire_zero_weight!(n, firedNeurons, nExInTypeList, 0)
end

function neuroplasticity!(n::outputNeuron, firedNeurons::Vector, nExInTypeList::Vector,
                          totalInputNeuron::Integer)
    _rewire_zero_weight!(n, firedNeurons, nExInTypeList, totalInputNeuron)
end

end # end module