From 213f35b1ff08b5d0cd2e7e35d58deb2e6a5b12e3 Mon Sep 17 00:00:00 2001 From: ton Date: Sun, 17 Sep 2023 13:57:54 +0700 Subject: [PATCH] rewrite neuroplasticity() --- src/forward.jl | 10 ++-- src/learn.jl | 135 ++----------------------------------------------- src/snnUtil.jl | 11 ++-- 3 files changed, 13 insertions(+), 143 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index ac44c62..3e97098 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -319,9 +319,8 @@ function lifForward( zit, wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) * zit[i1,i2,i3,i4] - if synapseReconnectDelay[i1,i2,i3,i4] < 0 && # negative value is counting mode - synapseReconnectDelay[i1,i2,i3,i4] < -0.2 # -0.1 < -0.1 won't work on GPU - + # negative value is counting mode, -0.1 < -0.1 won't work on GPU + if synapseReconnectDelay[i1,i2,i3,i4] < -0.2 synapseReconnectDelay[i1,i2,i3,i4] += 1 if synapseReconnectDelay[i1,i2,i3,i4] == 0 # mark timestep @@ -528,9 +527,8 @@ function alifForward( zit, wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - avth[i1,i2,i3,i4]) * zit[i1,i2,i3,i4] - if synapseReconnectDelay[i1,i2,i3,i4] < 0 && # negative value is counting mode - synapseReconnectDelay[i1,i2,i3,i4] < -0.2 # -0.1 < -0.1 won't work on GPU - + # negative value is counting mode, -0.1 < -0.1 won't work on GPU + if synapseReconnectDelay[i1,i2,i3,i4] < -0.2 synapseReconnectDelay[i1,i2,i3,i4] += 1 if synapseReconnectDelay[i1,i2,i3,i4] == 0 # mark timestep diff --git a/src/learn.jl b/src/learn.jl index 1f63dc7..9a8ddf5 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -545,7 +545,7 @@ function neuroplasticity(synapseConnectionNumber, # skip neuroplasticity #TODO I may need to do something with neuronInactivityCounter and other variables wRecChange .= 0 - error("DEBUG -> neuroplasticity") + # error("DEBUG -> neuroplasticity") elseif progress != 0 # progress increase # ready to reconnect synapse must not have wRecChange mask = (!isequal).(wRec, 0) @@ -566,8 +566,7 @@ function neuroplasticity(synapseConnectionNumber, # rewire synapse connection rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay, zitCumulative) - error("DEBUG -> neuroplasticity") - + # error("DEBUG -> neuroplasticity 1") elseif progress == 0 # no progress, no weight update, only rewire # -w all non-fire connection except mature connection weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) @@ -579,143 +578,15 @@ function neuroplasticity(synapseConnectionNumber, rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay, zitCumulative) - error("DEBUG -> neuroplasticity $(Dates.now())") + # error("DEBUG -> neuroplasticity") else error("undefined condition line $(@__LINE__)") end - - # error("DEBUG -> neuroplasticity $(Dates.now())") - - - - - - # # for each neuron, find total number of synaptic conn that should draw - # # new connection to firing and non-firing neurons pool - # subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber)) - - # # for each neuron, count how many synapse already subscribed to firing-neurons - # zw = zitCumulative .* wRec - # subToFireNeuron_current = sum(GeneralUtils.GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) - # zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 - # projection = ones(i1,i2,i3) - # zitMask = zitMask .* projection # (row, col, n) - # totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n) - # println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced") - - # # clear -1.0 marker - # GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99) - # GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required - - # for i in 1:i3 - # if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight - # println("neuron die") - # neuronInactivityCounter[:,:,i] .= 0 # reset - # w = random_wRec(i1,i2,1,synapseConnectionNumber) - # wRec[:,:,i] .= w - - # a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron - # mask = (!iszero).(w) - # GeneralUtils.replaceElements!(mask, 1, a, 0) - # synapseReconnectDelay[:,:,i] = a - # else - # remaining = 0 - # if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe - # toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] - # totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn - # # add new conn to firing neurons pool - # remaining = addNewSynapticConn!(zitMask[:,:,i], 1, - # @view(wRec[:,:,i]), - # @view(synapseReconnectDelay[:,:,i]), - # toAddConn) - # totalNewConn[1,1,i] += remaining - # end - - # # add new conn to non-firing neurons pool - # remaining = addNewSynapticConn!(zitMask[:,:,i], 0, - # @view(wRec[:,:,i]), - # @view(synapseReconnectDelay[:,:,i]), - # totalNewConn[1,1,i]) - # if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot - # remaining = addNewSynapticConn!(zitMask[:,:,i], 1, - # @view(wRec[:,:,i]), - # @view(synapseReconnectDelay[:,:,i]), - # remaining) - # end - # end - # end # error("DEBUG -> neuroplasticity $(Dates.now())") return wRec end -# function neuroplasticity(synapseConnectionNumber, -# zitCumulative, # (row, col) -# wRec, # (row, col, n) -# neuronInactivityCounter, -# synapseReconnectDelay) # (row, col, n) - -# i1,i2,i3 = size(wRec) - -# # for each neuron, find total number of synaptic conn that should draw -# # new connection to firing and non-firing neurons pool -# subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber)) - -# # for each neuron, count how many synap already subscribed to firing-neurons -# zw = zitCumulative .* wRec -# subToFireNeuron_current = sum(GeneralUtils.GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) -# zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 -# projection = ones(i1,i2,i3) -# zitMask = zitMask .* projection # (row, col, n) -# totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n) -# println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced") - -# # clear -1.0 marker -# GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99) -# GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required - -# for i in 1:i3 -# if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight -# println("neuron die") -# neuronInactivityCounter[:,:,i] .= 0 # reset -# w = random_wRec(i1,i2,1,synapseConnectionNumber) -# wRec[:,:,i] .= w - -# a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron -# mask = (!iszero).(w) -# GeneralUtils.replaceElements!(mask, 1, a, 0) -# synapseReconnectDelay[:,:,i] = a -# else -# remaining = 0 -# if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe -# toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] -# totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn -# # add new conn to firing neurons pool -# remaining = addNewSynapticConn!(zitMask[:,:,i], 1, -# @view(wRec[:,:,i]), -# @view(synapseReconnectDelay[:,:,i]), -# toAddConn) -# totalNewConn[1,1,i] += remaining -# end - -# # add new conn to non-firing neurons pool -# remaining = addNewSynapticConn!(zitMask[:,:,i], 0, -# @view(wRec[:,:,i]), -# @view(synapseReconnectDelay[:,:,i]), -# totalNewConn[1,1,i]) -# if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot -# remaining = addNewSynapticConn!(zitMask[:,:,i], 1, -# @view(wRec[:,:,i]), -# @view(synapseReconnectDelay[:,:,i]), -# remaining) -# end -# end -# end - -# # error("DEBUG -> neuroplasticity $(Dates.now())") -# return wRec -# end - # learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5 function learningLiquidity(x) diff --git a/src/snnUtil.jl b/src/snnUtil.jl index 9217d5f..b99503f 100644 --- a/src/snnUtil.jl +++ b/src/snnUtil.jl @@ -88,7 +88,7 @@ function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractAr # error("DEBUG -> mergeLearnWeight!") end -function growRepeatedPath!(wRec, synapticActivityCounter, eta) #BUG wRec get all 0 +function growRepeatedPath!(wRec, synapticActivityCounter, eta) # seperate active synapse out of inactive in this signal mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) # println("synapticActivityCounter ", synapticActivityCounter[:,:,1,1]) @@ -134,6 +134,7 @@ end function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) + println("weak synapse ", sum(mask_weak)) mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned # all weak synapse activity are reset @@ -164,23 +165,23 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr else for ind in eachindex(synapseReconnectDelay[:,:,n,i4]) timemark = synapseReconnectDelay[:,:,n,i4][ind] - if timemark > 0 # mark timeStep available + # println("timemark ", timemark) + if timemark > 0 #TODO not fully tested. mark timeStep available # get neuron pool at 10 timeStep earlier earlier = timemark - 10 > 0 ? timemark - 10 : timemark pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) - println(">>> ", sum(pool)) if sum(pool) != 0 indices = findall(x -> x != 0, pool) pick = rand(indices) wRec[pick] = rand(0.01:0.01:0.05) synapticActivityCounter[pick] = 0 synapseReconnectDelay[pick] = -0.1 + error("DEBUG -> rewireSynapse!") else # if neurons not firing at all, try again next time synapticActivityCounter[pick] = 0 synapseReconnectDelay[:,:,n,i4] = rand(1:1000) + error("DEBUG -> rewireSynapse!") end - - error("DEBUG -> rewireSynapse!") end end end