From 6b58d8480490700ee8e3b0273ca9823ce5986fe4 Mon Sep 17 00:00:00 2001 From: ton Date: Sat, 16 Sep 2023 22:06:44 +0700 Subject: [PATCH] dev --- src/forward.jl | 16 +-- src/learn.jl | 359 +++++++++++++++++++++---------------------------- src/snnUtil.jl | 156 +++++++++++++++++++-- src/type.jl | 10 +- 4 files changed, 308 insertions(+), 233 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index c38301c..ac44c62 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -319,10 +319,10 @@ function lifForward( zit, wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) * zit[i1,i2,i3,i4] - if !iszero(wRec[i1,i2,i3,i4]) && # check if this is wRec subscription - synapseReconnectDelay[i1,i2,i3,i4] != 0 - - synapseReconnectDelay[i1,i2,i3,i4] -= 1 + if synapseReconnectDelay[i1,i2,i3,i4] < 0 && # negative value is counting mode + synapseReconnectDelay[i1,i2,i3,i4] < -0.2 # -0.1 < -0.1 won't work on GPU + + synapseReconnectDelay[i1,i2,i3,i4] += 1 if synapseReconnectDelay[i1,i2,i3,i4] == 0 # mark timestep synapseReconnectDelay[i1,i2,i3,i4] = sum(timeStep) @@ -528,10 +528,10 @@ function alifForward( zit, wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - avth[i1,i2,i3,i4]) * zit[i1,i2,i3,i4] - if !iszero(wRec[i1,i2,i3,i4]) && # check if this is wRec subscription - synapseReconnectDelay[i1,i2,i3,i4] != 0 - - synapseReconnectDelay[i1,i2,i3,i4] -= 1 + if synapseReconnectDelay[i1,i2,i3,i4] < 0 && # negative value is counting mode + synapseReconnectDelay[i1,i2,i3,i4] < -0.2 # -0.1 < -0.1 won't work on GPU + + synapseReconnectDelay[i1,i2,i3,i4] += 1 if synapseReconnectDelay[i1,i2,i3,i4] == 0 # mark timestep synapseReconnectDelay[i1,i2,i3,i4] = sum(timeStep) diff --git a/src/learn.jl b/src/learn.jl index d637100..f69e3db 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -94,7 +94,7 @@ function lifComputeParamsChange!( timeStep::CuArray, # some RSNN neuron that has direct connection to output neuron need to get Bjk # from output neuron that represent correct answer, the rest of RSNN get random Bjk - onW = @view(wOut[:, startCol:stopCol, sum(label), 1]) + onW = @view(wOut[:, startCol:stopCol, sum(label+1), 1]) # label+1 because julia is 1-based index _bk = @view(bk[:, startCol:stopCol, 1]) mask = iszero.(onW) bk_ = mask .* _bk @@ -175,7 +175,7 @@ function alifComputeParamsChange!( timeStep::CuArray, # some RSNN neuron that has direct connection to output neuron need to get Bjk # from output neuron that represent correct answer, the rest of RSNN get random Bjk - onW = @view(wOut[:, startCol:stopCol, sum(label), 1]) + onW = @view(wOut[:, startCol:stopCol, sum(label+1), 1]) # label+1 because julia is 1-based index _bk = @view(bk[:, startCol:stopCol, 1]) mask = iszero.(onW) bk_ = mask .* _bk @@ -423,20 +423,13 @@ function lifLearn(wRec, arrayProjection4d_cpu = arrayProjection4d |> cpu wRec_cpu = wRec |> cpu wRecChange_cpu = wRecChange |> cpu - wRecChange_cpu = wRecChange_cpu[:,:,:,1] eta_cpu = eta |> cpu - eta_cpu = eta_cpu[:,:,:,1] exInType_cpu = exInType |> cpu - exInType_cpu = exInType_cpu[:,:,:,1] neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu - neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n) synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu - synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1] synapticActivityCounter_cpu = synapticActivityCounter |> cpu - synapticActivityCounter_cpu = synapticActivityCounter_cpu[:,:,:,1] zitCumulative_cpu = zitCumulative |> cpu - zitCumulative_cpu = zitCumulative_cpu[:,:,1] - + println("synapse 3 ", synapseReconnectDelay_cpu[:,:,1,1]) # neuroplasticity, work on CPU side wRec_cpu, neuronInactivityCounter_cpu, synapseReconnectDelay_cpu = neuroplasticity(synapseConnectionNumber, @@ -445,44 +438,11 @@ function lifLearn(wRec, exInType_cpu, wRecChange_cpu, vt, - eta, + eta_cpu, neuronInactivityCounter_cpu, synapseReconnectDelay_cpu, synapticActivityCounter_cpu, progress,) - - - - - - - - - # # merge learning weight with average learning weight of all batch - # wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d - # wRec .= (exInType .* wRec) .+ wch - - # # (row, col) - - # # -W if less than 10% of repeat avg, +W otherwise - # _, _, i3 = size(wRec_cpu) - # for i in 1:i3 - # x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i])) - # mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1) - # wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i] - # end - - # # weak / negative synaptic connection will get randomed in neuroplasticity() - # wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0 - - # neuroplasticity, work on CPU side - # wRec_cpu = neuroplasticity(synapseConnectionNumber, - # zitCumulative_cpu, - # wRec_cpu, - # wRecChange_cpu, - # vt, - # neuronInactivityCounter_cpu, - # synapseReconnectDelay_cpu) # transfer data backto gpu wRec_cpu = wRec_cpu .* arrayProjection4d_cpu @@ -491,7 +451,7 @@ function lifLearn(wRec, neuronInactivityCounter = neuronInactivityCounter_cpu |> device synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu synapseReconnectDelay = synapseReconnectDelay_cpu |> device - + error("DEBUG -> lifLearn! $(Dates.now())") return wRec, neuronInactivityCounter, synapseReconnectDelay end @@ -515,15 +475,10 @@ function alifLearn(wRec, arrayProjection4d_cpu = arrayProjection4d |> cpu wRec_cpu = wRec |> cpu - wRec_cpu = wRec_cpu[:,:,:,1] # since every batch has the same neuron wRec, (row, col, n) eta_cpu = eta |> cpu - eta_cpu = eta_cpu[:,:,:,1] neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu - neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n) synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu - synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1] zitCumulative_cpu = zitCumulative |> cpu - zitCumulative_cpu = zitCumulative_cpu[:,:,1] # (row, col) # -W if less than 10% of repeat avg, +W otherwise _, _, i3 = size(wRec_cpu) @@ -566,10 +521,9 @@ function onLearn!(wOut, wOut .-= 0.001 .* wOut end -alltrue(args...) = false ∈ [args...] ? false : true -isbetween(x, lowerlimit, upperlimit) = lowerlimit < x < upperlimit ? true : false +GeneralUtils.allTrue(args...) = false ∈ [args...] ? false : true -#WORKING 1) implement 90% +w, 10% -w 2) rewrite this function +#WORKING 2) rewrite this function function neuroplasticity(synapseConnectionNumber, zitCumulative, # (row, col) wRec, # (row, col, n) @@ -583,196 +537,191 @@ function neuroplasticity(synapseConnectionNumber, progress,) # (row, col, n) i1,i2,i3 = size(wRec) println("eta $(size(eta))") - println("wRec $(size(wRec))") - error("DEBUG -> neuroplasticity $(Dates.now())") + println("wRec 1 $(size(wRec)) ", wRec[:,:,1,1]) + println("zitCumulative $(size(zitCumulative))") + + println("progress $progress") if progress == 2 # no need to learn # skip neuroplasticity #TODO I may need to do something with neuronInactivityCounter and other variables wRecChange .= 0 - elseif progress == 1 # progress increase + error("DEBUG -> neuroplasticity") + elseif progress != 0 # progress increase # ready to reconnect synapse must not have wRecChange mask = (!isequal).(wRec, 0) wRecChange .*= mask - - # merge learning weight with average learning weight of all batch - wRec .= abs.((exInType .* wRec) .+ wRecChange) # abs because wRec doesn't carry sign - - # seperate active synapse out of inactive in this signal - mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) - mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) - - # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec - avgActivity = sum(synapticActivityCounter) / length(synapticActivityCounter) - lowerlimit = 0.1 * avgActivity - # +w, synapse with more than 10% of avg activity get increase weight by eta - mask_more = (!isless).(synapticActivityCounter, lowerlimit) - mask_2 = alltrue.(mask_activeSynapse, mask_more) - mask_2 .*= 1 .+ eta # minor activity synapse weight will be reduced by eta - wRec .*= mask_2 + # merge learning weight, all resulting negative wRec will get pruned + mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter, synapseReconnectDelay) + # println("wRec 2 $(size(wRec)) ", wRec[:,:,1,1]) + # adjust wRec based on repeatition (90% +w, 10% -w) + growRepeatedPath!(wRec, synapticActivityCounter, eta) + # println("wRec 3 $(size(wRec)) ", wRec[:,:,1,1]) + # -w all non-fire connection except mature connection + weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) - # -w, synapse with less than 10% of avg activity get reduced weight by eta - mask_less = isbetween.(synapticActivityCounter, 0.0, lowerlimit) # 1st criteria + # prune weak synapse + pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) - mask_3 = alltrue.(mask_activeSynapse, mask_less) - mask_3 .*= 1 .- eta # minor activity synapse weight will be reduced by eta - wRec .*= mask_3 + # rewire synapse connection + rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, + synapseReconnectDelay, zitCumulative) + + error("DEBUG -> neuroplasticity") + + # wRec .= (exInType .* wRec) .+ wRecChange + # mask_negative = isless.(wRec, 0) + # mask_positive = (!isless).(wRec, 0) + # GeneralUtils.replaceElements!(mask_negative, 1, wRec, 0.0) # negative synapse get pruned + # GeneralUtils.replaceElements!(mask_negative, 1, synapticActivityCounter, -0.1) + # # set pruned synapse to random wait time + # waittime = rand((1:1000), size(wRec)) .* mask_negative # synapse's random wait time to reconnect + # # synapseReconnectDelay counting mode when value is negative hence .* -1 + # synapseReconnectDelay .= (synapseReconnectDelay .* mask_positive) .+ (waittime .* -1) + + # # seperate active synapse out of inactive in this signal + # mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) + + # # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec + # avgActivity = sum(synapticActivityCounter) / length(synapticActivityCounter) + # lowerlimit = 0.1 * avgActivity + + # # +w, synapse with more than 10% of avg activity get increase weight by eta + # mask_more = (!isless).(synapticActivityCounter, lowerlimit) + # mask_2 = GeneralUtils.allTrue.(mask_activeSynapse, mask_more) + # mask_2 .*= 1 .+ eta # minor activity synapse weight will be reduced by eta + # wRec .*= mask_2 + + # # -w, synapse with less than 10% of avg activity get reduced weight by eta + # mask_less = GeneralUtils.isBetween.(synapticActivityCounter, 0.0, lowerlimit) # 1st criteria + + # mask_3 = GeneralUtils.allTrue.(mask_activeSynapse, mask_less) + # mask_3 .*= 1 .- eta # minor activity synapse weight will be reduced by eta + # wRec .*= mask_3 # -w all non-fire connection except mature connection - mask_notmature = isless.(wRec, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 - mask_1 = alltrue.(mask_inactiveSynapse, mask_notmature) - mask_1 .*= 1 .- eta - wRec .*= mask_1 + # mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) + # mask_notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 + # mask_1 = GeneralUtils.allTrue.(mask_inactiveSynapse, mask_notmature) + # mask_1 .*= 1 .- eta + # wRec .*= mask_1 # prune synapse - mask_weak = isbetween.(wRec, 0.0, 0.01) - mask_notweak = (!isbetween).(wRec, 0.0, 0.01) - wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned - r = rand((1:1000), size(wRec)) .* mask_weak # synapse random wait time to reconnect - synapticActivityCounter .*= mask_notweak # all marked weak synapse activity are reset - synapticActivityCounter .+= (mask_weak .* -1.0) - synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ r # set pruned synapse to random wait time + # mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) + # mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) + # wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned + # # all weak synapse activity are reset + # GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, -0.1) + # # set pruned synapse to random wait time + # r = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect + # # synapseReconnectDelay counting mode when value is negative hence .* -1 + # synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (r .* -1) - #WORKING rewire synapse connection - synapseReconnectDelay mark timeStep while also counting delay == BUG - - for i in 1:i3 # neuron-by-neuron - if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight - println("neuron die") - neuronInactivityCounter[:,:,i] .= 0 # reset - w = random_wRec(i1,i2,1,synapseConnectionNumber) - wRec[:,:,i] .= w - - #WORKING - a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron - mask = (!iszero).(w) - GeneralUtils.replaceElements!(mask, 1, a, 0) - synapseReconnectDelay[:,:,i] = a - else - remaining = 0 - if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe - toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] - totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn - # add new conn to firing neurons pool - remaining = addNewSynapticConn!(zitMask[:,:,i], 1, - @view(wRec[:,:,i]), - @view(synapseReconnectDelay[:,:,i]), - toAddConn) - totalNewConn[1,1,i] += remaining - end - - # add new conn to non-firing neurons pool - remaining = addNewSynapticConn!(zitMask[:,:,i], 0, - @view(wRec[:,:,i]), - @view(synapseReconnectDelay[:,:,i]), - totalNewConn[1,1,i]) - if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot - remaining = addNewSynapticConn!(zitMask[:,:,i], 1, - @view(wRec[:,:,i]), - @view(synapseReconnectDelay[:,:,i]), - remaining) - end - end - end + # rewire synapse connection + # for i in 1:i3 # neuron-by-neuron + # if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight + # println("neuron $i die") + # neuronInactivityCounter[:,:,i] .= 0 # reset + # w = random_wRec(i1,i2,1,synapseConnectionNumber) + # wRec[:,:,i] .= w + # a = similar(w) .= -0.1 # synapseConnectionNumber of this neuron + # mask = (!iszero).(w) + # GeneralUtils.replaceElements!(mask, 1, a, 0) + # synapseReconnectDelay[:,:,i] = a + # else #WORKING + # for i in eachindex(synapseReconnectDelay[:,:,i]) + # if i > 0 # mark timeStep available + # # get neuron pool at 10 timeStep earlier + # earlier = i - 10 > 0 ? i : i - 10 + # pool = sum(zitCumulative[:,:,earlier:i], dims=3) + # indices = findall(x -> x != 0, pool) + # pick = rand(indices) + # wRec[:,:,i][pick] = rand(0.01:0.01:0.5) + # synapticActivityCounter[:,:,i][pick] = 0 + # synapseReconnectDelay[:,:,i][pick] = -0.1 + # end + # end + # end + # end elseif progress == 0 # no progress, no weight update, only rewire - # -w all non-fire connection except mature connection + weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) - # prune weak connection + # prune weak synapse + pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) # rewire synapse connection - elseif progress == -1 # setback - # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec + rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, + synapseReconnectDelay, zitCumulative) - # -w all non-fire connection except mature connection - - # prune weak connection - - # rewire synapse connection + error("DEBUG -> neuroplasticity $(Dates.now())") else error("undefined condition line $(@__LINE__)") end # error("DEBUG -> neuroplasticity $(Dates.now())") - # merge learning weight with average learning weight of all batch - wRec .= abs.((exInType .* wRec) .+ wRecChange) # abs because wRec doesn't carry sign - - - # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec - - - - # -w all non-fire connection except mature connection - - - - # prune weak connection - # rewire synapse connection + # # for each neuron, find total number of synaptic conn that should draw + # # new connection to firing and non-firing neurons pool + # subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber)) + # # for each neuron, count how many synapse already subscribed to firing-neurons + # zw = zitCumulative .* wRec + # subToFireNeuron_current = sum(GeneralUtils.GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) + # zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 + # projection = ones(i1,i2,i3) + # zitMask = zitMask .* projection # (row, col, n) + # totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n) + # println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced") - - # for each neuron, find total number of synaptic conn that should draw - # new connection to firing and non-firing neurons pool - subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber)) - - # for each neuron, count how many synapse already subscribed to firing-neurons - zw = zitCumulative .* wRec - subToFireNeuron_current = sum(GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) - zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 - projection = ones(i1,i2,i3) - zitMask = zitMask .* projection # (row, col, n) - totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n) - println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced") - - # clear -1.0 marker - GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99) - GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required + # # clear -1.0 marker + # GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99) + # GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required - for i in 1:i3 - if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight - println("neuron die") - neuronInactivityCounter[:,:,i] .= 0 # reset - w = random_wRec(i1,i2,1,synapseConnectionNumber) - wRec[:,:,i] .= w + # for i in 1:i3 + # if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight + # println("neuron die") + # neuronInactivityCounter[:,:,i] .= 0 # reset + # w = random_wRec(i1,i2,1,synapseConnectionNumber) + # wRec[:,:,i] .= w - a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron - mask = (!iszero).(w) - GeneralUtils.replaceElements!(mask, 1, a, 0) - synapseReconnectDelay[:,:,i] = a - else - remaining = 0 - if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe - toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] - totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn - # add new conn to firing neurons pool - remaining = addNewSynapticConn!(zitMask[:,:,i], 1, - @view(wRec[:,:,i]), - @view(synapseReconnectDelay[:,:,i]), - toAddConn) - totalNewConn[1,1,i] += remaining - end + # a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron + # mask = (!iszero).(w) + # GeneralUtils.replaceElements!(mask, 1, a, 0) + # synapseReconnectDelay[:,:,i] = a + # else + # remaining = 0 + # if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe + # toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] + # totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn + # # add new conn to firing neurons pool + # remaining = addNewSynapticConn!(zitMask[:,:,i], 1, + # @view(wRec[:,:,i]), + # @view(synapseReconnectDelay[:,:,i]), + # toAddConn) + # totalNewConn[1,1,i] += remaining + # end - # add new conn to non-firing neurons pool - remaining = addNewSynapticConn!(zitMask[:,:,i], 0, - @view(wRec[:,:,i]), - @view(synapseReconnectDelay[:,:,i]), - totalNewConn[1,1,i]) - if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot - remaining = addNewSynapticConn!(zitMask[:,:,i], 1, - @view(wRec[:,:,i]), - @view(synapseReconnectDelay[:,:,i]), - remaining) - end - end - end + # # add new conn to non-firing neurons pool + # remaining = addNewSynapticConn!(zitMask[:,:,i], 0, + # @view(wRec[:,:,i]), + # @view(synapseReconnectDelay[:,:,i]), + # totalNewConn[1,1,i]) + # if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot + # remaining = addNewSynapticConn!(zitMask[:,:,i], 1, + # @view(wRec[:,:,i]), + # @view(synapseReconnectDelay[:,:,i]), + # remaining) + # end + # end + # end # error("DEBUG -> neuroplasticity $(Dates.now())") return wRec @@ -792,7 +741,7 @@ end # # for each neuron, count how many synap already subscribed to firing-neurons # zw = zitCumulative .* wRec -# subToFireNeuron_current = sum(GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) +# subToFireNeuron_current = sum(GeneralUtils.GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) # zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 # projection = ones(i1,i2,i3) # zitMask = zitMask .* projection # (row, col, n) diff --git a/src/snnUtil.jl b/src/snnUtil.jl index e7bd345..00bbae2 100644 --- a/src/snnUtil.jl +++ b/src/snnUtil.jl @@ -1,8 +1,9 @@ module snnUtil -export refractoryStatus!, addNewSynapticConn! +export refractoryStatus!, addNewSynapticConn!, mergeLearnWeight!, growRepeatedPath!, + weakenNotMatureSynapse!, pruneSynapse!, rewireSynapse! -using Random +using Random, GeneralUtils #------------------------------------------------------------------------------------------------100 @@ -38,7 +39,7 @@ end # return sqrt(distance) # end -function addNewSynapticConn!(mask::AbstractArray{<:Any}, x::Number, wRec::AbstractArray{<:Any}, +function addNewSynapticConn!(mask::AbstractArray{<:Any}, markValue::Number, wRec::AbstractArray{<:Any}, counter::AbstractArray{<:Any}, n=0; rng::AbstractRNG=MersenneTwister(1234)) # println("mask ", mask, size(mask)) @@ -56,8 +57,8 @@ function addNewSynapticConn!(mask::AbstractArray{<:Any}, x::Number, wRec::Abstra if size(mask) != size(wRec) error("mask and wRec must have the same size") end - # get the indices of elements in mask that equal x - indices = findall(x -> x == x, mask) + # get the indices of elements in mask that equal markValue + indices = findall(x -> x == markValue, mask) alreadySub = findall(x -> x != 0, wRec) # get already subscribe setdiff!(indices, alreadySub) # remove already sub conn from pool @@ -81,6 +82,142 @@ function addNewSynapticConn!(mask::AbstractArray{<:Any}, x::Number, wRec::Abstra return remaining end +# function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractArray, +# synapticActivityCounter::AbstractArray, +# synapseReconnectDelay::AbstractArray) +# println("wRec 2 $(size(wRec)) ", wRec[:,:,1,1]) +# println("wRecChange ", wRecChange[:,:,1,1]) +# #WORKING look for flipped sign, it needs to get pruned +# wRec .= (exInType .* wRec) .+ wRecChange +# println("wRec 3 $(size(wRec)) ", wRec[:,:,1,1]) +# mask_negative = isless.(wRec, 0) +# mask_positive = (!isless).(wRec, 0) +# GeneralUtils.replaceElements!(mask_negative, 1, wRec, 0.0) # negative synapse get pruned +# println("wRec 4 $(size(wRec)) ", wRec[:,:,1,1]) +# GeneralUtils.replaceElements!(mask_negative, 1, synapticActivityCounter, -0.1) +# # set pruned synapse to random wait time +# waittime = rand((1:1000), size(wRec)) .* mask_negative # synapse's random wait time to reconnect +# # synapseReconnectDelay counting mode when value is negative hence .* -1 +# synapseReconnectDelay .= (synapseReconnectDelay .* mask_positive) .+ (waittime .* -1) +# error("DEBUG -> mergeLearnWeight!") +# end + +function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractArray, + synapticActivityCounter::AbstractArray, + synapseReconnectDelay::AbstractArray) + wRecSigned = exInType .* wRec + # -0.0 == 0.0 but isequal() implement as -0.0 != 0.0, so Ineed to get rid of -0.0 manually + GeneralUtils.replaceElements!(wRecSigned, -0, 0) + # println("wRec 2 $(size(wRecSigned)) ", wRecSigned[:,:,1,1]) + # println("wRecChange ", wRecChange[:,:,1,1]) + originalsign = sign.(wRecSigned) + + # println("originalsign ", originalsign[:,:,1,1]) + wRecSigned .= wRecSigned .+ wRecChange + # println("wRec 3 $(size(wRecSigned)) ", wRecSigned[:,:,1,1]) + newsign = sign.(wRecSigned) # look for flipped sign, it needs to get pruned + + # println("newsign ", newsign[:,:,1,1]) + flipsign = (!isequal).(originalsign, newsign) + # println("flipsign ", flipsign[:,:,1,1]) + nonflipsign = (isequal).(originalsign, newsign) + wRec .= abs.(wRecSigned) + + println("wRec 4 $(size(wRec)) ", wRec[:,:,1,1]) + GeneralUtils.replaceElements!(flipsign, 1, wRec, 0.0) # negative synapse get pruned + println("wRec 5 $(size(wRec)) ", wRec[:,:,1,1]) + GeneralUtils.replaceElements!(flipsign, 1, synapticActivityCounter, -0.1) + println("synapticActivityCounter ", synapticActivityCounter[:,:,1,1]) #BUG why 0.0 alot? + # set pruned synapse to random wait time + waittime = rand((1:1000), size(wRec)) .* flipsign # synapse's random wait time to reconnect + # synapseReconnectDelay counting mode when value is negative hence .* -1 + synapseReconnectDelay .= (synapseReconnectDelay .* nonflipsign) .+ (waittime .* -1) + println("synapseReconnectDelay ", synapseReconnectDelay[:,:,1,1]) + error("DEBUG -> mergeLearnWeight!") +end + +function growRepeatedPath!(wRec, synapticActivityCounter, eta) #BUG wRec get all 0 + # seperate active synapse out of inactive in this signal + mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) + + # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec + avgActivity = sum(synapticActivityCounter) / length(synapticActivityCounter) + lowerlimit = 0.1 * avgActivity + + # +w, synapse with more than 10% of avg activity get increase weight by eta + mask_more = (!isless).(synapticActivityCounter, lowerlimit) + mask_2 = GeneralUtils.allTrue.(mask_activeSynapse, mask_more) + mask_3 = mask_2 .* (1 .+ eta) # minor activity synapse weight will be reduced by eta + wRec .*= mask_3 + + # -w, synapse with less than 10% of avg activity get reduced weight by eta + mask_less = GeneralUtils.isBetween.(synapticActivityCounter, 0.0, lowerlimit) # 1st criteria + mask_3 = GeneralUtils.allTrue.(mask_activeSynapse, mask_less) + mask_4 = mask_3 .* (1 .- eta) # minor activity synapse weight will be reduced by eta + wRec .*= mask_4 + error("DEBUG -> growRepeatedPath!") +end + +function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) + mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) + mask_notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 + mask_1 = GeneralUtils.allTrue.(mask_inactiveSynapse, mask_notmature) + mask_2 = mask_1 .* (1 .- eta) + wRec .*= mask_2 +end + +function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) + mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) + mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) + wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned + # all weak synapse activity are reset + GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, -0.1) + # set pruned synapse to random wait time + r = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect + # synapseReconnectDelay counting mode when value is negative hence .* -1 + synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (r .* -1) +end + +function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractArray, + synapticActivityCounter::AbstractArray, + synapseReconnectDelay::AbstractArray, + zitCumulative::AbstractArray) + _,_,i3,i4 = size(wRec) + for i in 1:i3 # neuron-by-neuron + if neuronInactivityCounter[1,1,i,i4][1] < -10000 # neuron die i.e. reset all weight + println("neuron $i die") + neuronInactivityCounter[:,:,i,i4] .= 0 # reset + w = random_wRec(i1,i2,1,synapseConnectionNumber) + wRec[:,:,i,i4] .= w + + a = similar(w) .= -0.1 # synapseConnectionNumber of this neuron + mask = (!iszero).(w) + GeneralUtils.replaceElements!(mask, 1, a, 0) + synapseReconnectDelay[:,:,i,i4] = a + else + for i in eachindex(synapseReconnectDelay[:,:,i,i4]) + println("synapse 0 ", synapseReconnectDelay[:,:,i,i4]) + if i > 0 # mark timeStep available + # get neuron pool at 10 timeStep earlier + earlier = i - 10 > 0 ? i - 10 : i + println("i $i earlier $earlier") + println("zit $(size(zitCumulative)) ") + pool = sum(zitCumulative[:,:,earlier:i], dims=3) + println("pool $(size(pool)) ", pool) + indices = findall(x -> x != 0, pool) + pick = rand(indices) + # println("wRec 1 ", wRec[:,:,i,i4]) + wRec[:,:,i,i4][pick] = rand(0.01:0.01:0.5) + # println("wRec 2 ", wRec[:,:,i,i4]) + synapticActivityCounter[:,:,i,i4][pick] = 0 + synapseReconnectDelay[:,:,i,i4][pick] = -0.1 + error("DEBUG -> rewireSynapse!") + end + end + end + end +end + @@ -96,15 +233,6 @@ end - - - - - - - - - diff --git a/src/type.jl b/src/type.jl index 911c516..74c5465 100644 --- a/src/type.jl +++ b/src/type.jl @@ -232,9 +232,9 @@ function kfn_1(params::Dict; device=cpu) # count subscribed synapse activity, just like epsilonRec but without decay. # use to adjust weight based on how often neural pathway is used - kfn.lif_synapseReconnectDelay = (similar(kfn.lif_wRec) .= -1.0) # -1.0 for non-sub conn + kfn.lif_synapseReconnectDelay = (similar(kfn.lif_wRec) .= -0.1) # -0.1 for non-sub conn - kfn.lif_synapticActivityCounter = Array(similar(kfn.lif_wRec) .= -1.0) # -1.0 for non-sub conn + kfn.lif_synapticActivityCounter = Array(similar(kfn.lif_wRec) .= -0.1) # -0.1 for non-sub conn mask = Array((!iszero).(kfn.lif_wRec)) # initial value subscribed conn GeneralUtils.replaceElements!(mask, 1, kfn.lif_synapticActivityCounter, 0.0) @@ -283,9 +283,9 @@ function kfn_1(params::Dict; device=cpu) kfn.alif_firingCounter = (similar(kfn.alif_wRec) .= 0) kfn.alif_firingTargetFrequency = (similar(kfn.alif_wRec) .= 0.1) kfn.alif_neuronInactivityCounter = (similar(kfn.alif_wRec) .= 0) - kfn.alif_synapseReconnectDelay = (similar(kfn.alif_wRec) .= -1.0) # -1.0 for non-sub conn + kfn.alif_synapseReconnectDelay = (similar(kfn.alif_wRec) .= -0.1) # -0.1 for non-sub conn - kfn.alif_synapticActivityCounter = Array(similar(kfn.alif_wRec) .= -1.0) # -1.0 for non-sub conn + kfn.alif_synapticActivityCounter = Array(similar(kfn.alif_wRec) .= -0.1) # -0.1 for non-sub conn mask = Array((!iszero).(kfn.alif_wRec)) # initial value subscribed conn GeneralUtils.replaceElements!(mask, 1, kfn.alif_synapticActivityCounter, 0.0) @@ -434,8 +434,6 @@ end - -