diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..bc62edc --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,7 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.8.5" +manifest_format = "2.0" +project_hash = "da39a3ee5e6b4b0d3255bfef95601890afd80709" + +[deps] diff --git a/src/SNNUtils.jl b/src/SNNUtils.jl index 256277c..30f4a35 100644 --- a/src/SNNUtils.jl +++ b/src/SNNUtils.jl @@ -1,5 +1,470 @@ module SNNUtils -greet() = print("Hello World!") +export updateVector!, resetParams!, resetParams_noWChange! -end # module SNNUtils +#using Statistics, Random + +# using ..Types, ..Utils + +#------------------------------------------------------------------------------------------------100 + +mul!(x::AbstractVector, y::AbstractVector) = x .*= y +mul(x::AbstractVector, y::AbstractVector) = x .* y + +ReLu(x::Number) = max(0, x) + +updateVector!(x::AbstractVector, target::Number) = x .= target +updateVector!(x::AbstractVector, target::AbstractArray) = x .= target +# selectAdd(x::AbstractVector, ind::AbstractVector, value::AbstractVector) = @. x + (ind * value) +function selectAdd!(x::AbstractVector, ind::AbstractVector, value::AbstractVector) + @. x = x + (ind * value) +end + +function timestepForward!(kfn::knowledgeFn) + kfn.zt_1 .= kfn.zt0 + kfn.zt0 .= 0 + kfn.on_out_t_1 .= kfn.on_out_t0 + kfn.on_out_t0 .= 0 + + kfn.lif_vt_1 .= kfn.lif_vt0 + kfn.lif_vt0 .= 0 + kfn.lif_zt_1 .= kfn.lif_zt0 + kfn.lif_zt0 .= 0 + # kfn.lif_epsilonInV_t_1 = deepcopy(kfn.lif_epsilonInV_t0) + updateVector!.(kfn.lif_epsilonInV_t0, 0.0) + # kfn.lif_epsilonRecV_t_1 = deepcopy(kfn.lif_epsilonRecV_t0) + updateVector!.(kfn.lif_epsilonRecV_t0, 0.0) + + kfn.alif_vt_1 .= kfn.alif_vt0 + kfn.alif_vt0 .= 0 + kfn.alif_zt_1 .= kfn.alif_zt0 + kfn.alif_zt0 .= 0 + kfn.alif_phi_t_1 .= kfn.alif_phi_t0 + kfn.alif_phi_t0 .= 0 + kfn.alif_epsilonInV_t_1 = deepcopy(kfn.alif_epsilonInV_t0) + updateVector!.(kfn.alif_epsilonInV_t0, 0.0) + kfn.alif_epsilonInA_t_1 = deepcopy(kfn.alif_epsilonInA_t0) + updateVector!.(kfn.alif_epsilonInA_t0, 0.0) + kfn.alif_epsilonRecV_t_1 = deepcopy(kfn.alif_epsilonRecV_t0) + updateVector!.(kfn.alif_epsilonRecV_t0, 0.0) + kfn.alif_epsilonRecA_t_1 = deepcopy(kfn.alif_epsilonRecA_t0) + updateVector!.(kfn.alif_epsilonRecA_t0, 0.0) +end + +function resetParams!(kfn::knowledgeFn) + updateVector!(kfn.zt_1, 0.0) + updateVector!(kfn.kfnError, 0.0) + + updateVector!(kfn.lif_lastFiringTime, 0.0) + updateVector!(kfn.lif_refractoryState, 0.0) + updateVector!(kfn.lif_vt_1, 0.0) + kfn.lif_vt0 .= 0 + updateVector!(kfn.lif_zt_1, 0.0) + kfn.lif_zt0 .= + kfn.lif_phi_t0 .= 0 + updateVector!.(kfn.lif_wInChange, 0.0) + updateVector!.(kfn.lif_wRecChange, 0.0) + updateVector!.(kfn.lif_wInChange_vReg, 0.0) + updateVector!.(kfn.lif_wRecChange_vReg, 0.0) + updateVector!(kfn.lif_vRegIn_a, 0.0) + updateVector!.(kfn.lif_vRegIn_b, 0.0) + updateVector!(kfn.lif_vRegInError, 0.0) + updateVector!(kfn.lif_vRegRec_a, 0.0) + updateVector!.(kfn.lif_vRegRec_b, 0.0) + updateVector!(kfn.lif_vRegRecError, 0.0) + updateVector!(kfn.lif_fCounter, 0.0) + updateVector!(kfn.lif_fRegDiff, 0.0) + updateVector!(kfn.lif_fRegError, 0.0) + updateVector!.(kfn.lif_inCount, 0.0) + updateVector!.(kfn.lif_ziCount, 0.0) + updateVector!.(kfn.lif_epsilonInV_t0, 0.0) + updateVector!.(kfn.lif_epsilonRecV_t0, 0.0) + updateVector!.(kfn.lif_eIn_timeAverage, 0.0) + updateVector!.(kfn.lif_eRec_timeAverage, 0.0) + kfn.lif_firingRate .= 0 + + + updateVector!(kfn.alif_lastFiringTime, 0.0) + updateVector!(kfn.alif_refractoryState, 0.0) + updateVector!(kfn.alif_vt_1, 0.0) + kfn.alif_vt0 .= 0 + updateVector!(kfn.alif_zt_1, 0.0) + kfn.alif_zt0 .= 0 + kfn.alif_phi_t_1 .= 0 + kfn.alif_phi_t0 .= 0 + updateVector!.(kfn.alif_wInChange, 0.0) + updateVector!.(kfn.alif_wRecChange, 0.0) + updateVector!.(kfn.alif_wInChange_vReg, 0.0) + updateVector!.(kfn.alif_wRecChange_vReg, 0.0) + updateVector!(kfn.alif_vRegIn_a, 0.0) + updateVector!.(kfn.alif_vRegIn_b, 0.0) + updateVector!(kfn.alif_vRegInError, 0.0) + updateVector!(kfn.alif_vRegRec_a, 0.0) + updateVector!.(kfn.alif_vRegRec_b, 0.0) + updateVector!(kfn.alif_vRegRecError, 0.0) + updateVector!(kfn.alif_fCounter, 0.0) + updateVector!(kfn.alif_fRegDiff, 0.0) + updateVector!(kfn.alif_fRegError, 0.0) + updateVector!.(kfn.alif_inCount, 0.0) + updateVector!.(kfn.alif_ziCount, 0.0) + updateVector!.(kfn.alif_epsilonInV_t_1, 0.0) + updateVector!.(kfn.alif_epsilonInV_t0, 0.0) + updateVector!.(kfn.alif_epsilonInA_t_1, 0.0) + updateVector!.(kfn.alif_epsilonInA_t0, 0.0) + updateVector!.(kfn.alif_epsilonRecV_t_1, 0.0) + updateVector!.(kfn.alif_epsilonRecV_t0, 0.0) + updateVector!.(kfn.alif_epsilonRecA_t_1, 0.0) + updateVector!.(kfn.alif_epsilonRecA_t0, 0.0) + updateVector!.(kfn.alif_eIn_timeAverage, 0.0) + updateVector!.(kfn.alif_eRec_timeAverage, 0.0) + kfn.alif_firingRate .= 0 + + updateVector!.(kfn.on_epsilonJ, 0.0) + updateVector!(kfn.on_epsilon_b, 0.0) + updateVector!(kfn.on_out_t_1, 0.0) + updateVector!(kfn.on_out_t0, 0.0) + updateVector!.(kfn.on_wOutChange, 0.0) + updateVector!(kfn.on_bChange, 0.0) + updateVector!(kfn.on_error, 0.0) +end + +""" just like resetParams but does not reset wChange and bChange +""" +function resetParams_noWChange!(kfn::knowledgeFn) + updateVector!(kfn.zt_1, 0.0) + updateVector!(kfn.kfnError, 0.0) + + updateVector!(kfn.lif_lastFiringTime, 0.0) + updateVector!(kfn.lif_refractoryState, 0.0) + updateVector!(kfn.lif_vt_1, 0.0) + kfn.lif_vt0 .= 0 + updateVector!(kfn.lif_zt_1, 0.0) + kfn.lif_zt0 .= + kfn.lif_phi_t0 .= 0 + updateVector!(kfn.lif_vRegIn_a, 0.0) + updateVector!.(kfn.lif_vRegIn_b, 0.0) + updateVector!(kfn.lif_vRegInError, 0.0) + updateVector!(kfn.lif_vRegRec_a, 0.0) + updateVector!.(kfn.lif_vRegRec_b, 0.0) + updateVector!(kfn.lif_vRegRecError, 0.0) + updateVector!(kfn.lif_fCounter, 0.0) + updateVector!(kfn.lif_fRegDiff, 0.0) + updateVector!(kfn.lif_fRegError, 0.0) + updateVector!.(kfn.lif_inCount, 0.0) + updateVector!.(kfn.lif_ziCount, 0.0) + updateVector!.(kfn.lif_epsilonInV_t0, 0.0) + updateVector!.(kfn.lif_epsilonRecV_t0, 0.0) + updateVector!.(kfn.lif_eIn_timeAverage, 0.0) + updateVector!.(kfn.lif_eRec_timeAverage, 0.0) + kfn.lif_firingRate .= 0 + + updateVector!(kfn.alif_lastFiringTime, 0.0) + updateVector!(kfn.alif_refractoryState, 0.0) + updateVector!(kfn.alif_vt_1, 0.0) + kfn.alif_vt0 .= 0 + updateVector!(kfn.alif_zt_1, 0.0) + kfn.alif_zt0 .= 0 + kfn.alif_phi_t_1 .= 0 + kfn.alif_phi_t0 .= 0 + updateVector!(kfn.alif_vRegIn_a, 0.0) + updateVector!.(kfn.alif_vRegIn_b, 0.0) + updateVector!(kfn.alif_vRegInError, 0.0) + updateVector!(kfn.alif_vRegRec_a, 0.0) + updateVector!.(kfn.alif_vRegRec_b, 0.0) + updateVector!(kfn.alif_vRegRecError, 0.0) + updateVector!(kfn.alif_fCounter, 0.0) + updateVector!(kfn.alif_fRegDiff, 0.0) + updateVector!(kfn.alif_fRegError, 0.0) + updateVector!.(kfn.alif_inCount, 0.0) + updateVector!.(kfn.alif_ziCount, 0.0) + updateVector!.(kfn.alif_epsilonInV_t_1, 0.0) + updateVector!.(kfn.alif_epsilonInV_t0, 0.0) + updateVector!.(kfn.alif_epsilonInA_t_1, 0.0) + updateVector!.(kfn.alif_epsilonInA_t0, 0.0) + updateVector!.(kfn.alif_epsilonRecV_t_1, 0.0) + updateVector!.(kfn.alif_epsilonRecV_t0, 0.0) + updateVector!.(kfn.alif_epsilonRecA_t_1, 0.0) + updateVector!.(kfn.alif_epsilonRecA_t0, 0.0) + updateVector!.(kfn.alif_eIn_timeAverage, 0.0) + updateVector!.(kfn.alif_eRec_timeAverage, 0.0) + kfn.alif_firingRate .= 0 + + updateVector!.(kfn.on_epsilonJ, 0.0) + updateVector!(kfn.on_epsilon_b, 0.0) + updateVector!(kfn.on_out_t_1, 0.0) + updateVector!(kfn.on_out_t0, 0.0) + updateVector!(kfn.on_error, 0.0) +end + +#------------------------------------------------------------------------------------------------100 + +firingRate!(fCounter, timeStamp) = (fCounter / timeStamp) * 1000 +#------------------------------------------------------------------------------------------------100 + +function cal_learningSignal(Bn::AbstractVector, error) + result = sum(Bn .* error) + return result +end + +function neuroplasticity!(id::AbstractVector, wRec::AbstractVector, + subscriptionList::AbstractVector, neuronFiresList::AbstractVector, probController::Number) + + subOption = memberFilter.((neuronFiresList,), id, subscriptionList) + shuffle!.(subOption) + zeroMarker = Utils.Zero.(wRec) + notzeroMarker = Utils.notZero.(wRec) + mul!.(subscriptionList, notzeroMarker) + + a = unmatchMul.(zeroMarker, subOption) + subscriptionList .+= a + + # newConnectionPercent = 11 - ((probController / 0.0001) / 10) # percent is in range 0.1 to 10 + newConnectionPercent = 11-((probController / 0.0001) / 10) # percent is in range 0.1 to 10 + prob = [newConnectionPercent, 100.0 - newConnectionPercent] / 100.0 + b = Utils.randomChoiceTarget.(zeroMarker, ([true, false],), (prob,)) .* (0.1*rand()) + mul!.(wRec, notzeroMarker) + wRec .+= b +end + +function memberFilter(vec::AbstractVector, excludeList1::Number, + excludeList2::AbstractVector) + result = filter(x -> x ∉ excludeList1 && x ∉ excludeList2 , vec) + return result +end + +""" x is primary dimension +""" +function unmatchMul(x::AbstractVector, y::AbstractVector) + # sometime length of x is larger than y, because no neuron fires. Increae total neuron number + if length(y) >= length(x) + a = @view y[1:length(x)] + x = x .* a + return x + else + # a = [1:length(x)...] + # x = x .* a + # return x + end +end + +function unmatchMul!(x::AbstractVector, y::AbstractVector) + a = @view y[1:length(x)] + x .*= a +end + +function refractoryState!(refractoryState::Number, timeStamp::Number, + lastFiringTime::Number, refractoryDuration::Number) + + if refractoryState == 1 && timeStamp - lastFiringTime <= refractoryDuration + # skip + else + refractoryState = 0 + end + + return refractoryState +end + +function refractoryStateUpdate(refractoryState::Number, zt0::Number) + if refractoryState == 0 && zt0 == 0 + return 0 + elseif refractoryState == 0 && zt0 == 1 + return 1 + elseif refractoryState == 1 && zt0 == 0 + return 1 + else + error("invalid neuron refractory status") + end +end + +lastFiringTimeUpdate!(lastFiringTime::Number, zt1::Number, timeStamp::Number) = + zt1 == 1 ? lastFiringTime = timeStamp : lastFiringTime + +fCounter!(fCounter::Number, zt1::Number) = + zt1 == 1 ? fCounter += 1 : fCounter + +# function zit(refractoryState::Number, zt0::AbstractVector, +# subscriptionList::AbstractVector) + +# return refractoryState == 1 ? zt0 * 0.0 : subscriptionList .* zt0 +# end + +function getNeuronFires(refractoryState::Number, zt0::AbstractVector, + subscriptionList::AbstractVector) + + return refractoryState == 1 ? zt0 * 0.0 : subscriptionList .* zt0 +end + +function incomingSignal(inputSignal::AbstractVector, weight::AbstractVector) + return sum(weight .* inputSignal) +end + +function neuronFiring(refractoryState::Number, inputSignal::Number, vt::Number, vth::Number) + vt1 = inputSignal + vt + + return refractoryState == 0 && vt1 > vth ? 1 : 0 +end + +function membranePotential(timeStamp::Number, + refractoryState::Number, + alpha::Number, + vt0::Number, + zt0::Number, + recurrentSignal::Number, + inputSignal::Number, + lastFiringTime::Number) + if refractoryState == 1 # exponantial decay + # vt1 = vt0 * (1 - alpha^(timeStamp - lastFiringTime)) # or n.v_t1 = n.alpha * n.v_t + vt1 = Utils.expDecay(vt0, alpha, timeStamp - lastFiringTime) + else + vt1 = (alpha * vt0) + recurrentSignal + inputSignal - zt0 + end + if vt1 === NaN + println("vt1 is NaN") + end + + return vt1 +end + +# function cal_ziCount(refractoryState::Number, alpha::Number, ziCount::AbstractVector, +# zit::AbstractVector) + +# if refractoryState == 1 +# return alpha .* ziCount +# else +# return (alpha .* ziCount) + zit +# end +# end + +function temporalFilter(refractoryState::Number, decayConstant::Number, + previousTemporalAverage::AbstractVector, + discreteSignal::AbstractVector) + if refractoryState == 1 + return decayConstant .* previousTemporalAverage + else + return (decayConstant .* previousTemporalAverage) + discreteSignal + end +end + +# function pseudoGradient(refractoryState::Number, gammaPd::Number, vt1::Number, vth::Number) +# if refractoryState == 1 +# return 0 +# else +# return (gammaPd / vth) * max(0, 1 - ((vt1 - vth) / vth)) +# end +# end + +function pseudoGradient(refractoryState::Number, gammaPd::Number, vt1::Number, vth::Number) + return (gammaPd / vth) * max(0, 1 - ((vt1 - vth) / vth)) +end + +# function lif_eligibilityTrace(refractoryState::Number, phi::Number, epsilonRecV::AbstractVector) +# if refractoryState == 1 +# return 0.0 .* epsilonRecV +# else +# return phi .* epsilonRecV +# end +# end + +function lif_eligibilityTrace(refractoryState::Number, phi::Number, epsilonRecV::AbstractVector) + return phi .* epsilonRecV +end + +function cal_expDecay(zt1::Number, vt1::Number, alpha::Number; timePass::Number=1) + vt1 = zt1 == 1 ? Utils.expDecay(vt1, alpha, timePass) : vt1 # exponantial decay + return vt1 +end + +# function getZit!(zit::AbstractVector, subscriptionList::AbstractVector, kfn_zt0::AbstractVector) +# zit .= getindex(kfn_zt0, subscriptionList) +# end + +function thresholdAdaptation(a::Number, rho::Number, zt0) + return (rho * a) + ((1 - rho) * zt0) +end + +# function avth(refractoryState::Number, avth::Number, a::Number, vth::Number, +# beta::Number) +# return refractoryState == 1 ? avth : vth + (beta * a) +# end + +function avth(refractoryState::Number, avth::Number, a::Number, vth::Number, beta::Number) + return vth + (beta * a) +end + +# function alif_eligibilityTrace(refractoryState::Number, phi::Number, +# beta::Number, epsilonV_t0::AbstractVector, epsilonRecA::AbstractVector) +# if refractoryState == 1 +# return 0.0 .* epsilonV_t0 +# else +# return phi .* (epsilonV_t0 - (beta .* epsilonRecA)) +# end +# end + +function alif_eligibilityTrace(refractoryState::Number, phi::Number, + beta::Number, epsilonV_t0::AbstractVector, epsilonRecA::AbstractVector) + return phi .* (epsilonV_t0 - (beta .* epsilonRecA)) +end + + + + + + + + + +# function neuroplasticity!(id::AbstractVector, wRec::AbstractVector, +# subscriptionList::AbstractVector, neuronFiresList::AbstractVector) + +# subOption = memberFilter.((neuronFiresList,), id, subscriptionList) +# shuffle!.(subOption) +# zeroMarker = Utils.Zero.(wRec) +# notzeroMarker = Utils.notZero.(wRec) + +# a = unmatchMul.(zeroMarker, subOption) +# mul!.(subscriptionList, notzeroMarker) +# subscriptionList .+= a + +# # b = zeroMarker * 0.01 # new connection's initial weight +# b = zeroMarker * 0.1 +# mul!.(wRec, notzeroMarker) +# wRec .+= b +# end + + +function randomNewWeight!(membranePotential::Number, weights::AbstractVector, + threshold::Number) + if membranePotential < threshold + return weights .= randn(length(weights)) + end +end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +end # end module \ No newline at end of file diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..e69de29