diff --git a/src/learn.jl b/src/learn.jl index 6a7bcbd..1f63dc7 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -562,88 +562,11 @@ function neuroplasticity(synapseConnectionNumber, # prune weak synapse pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) - error("DEBUG -> neuroplasticity") + # rewire synapse connection rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay, zitCumulative) - - - - # wRec .= (exInType .* wRec) .+ wRecChange - # mask_negative = isless.(wRec, 0) - # mask_positive = (!isless).(wRec, 0) - # GeneralUtils.replaceElements!(mask_negative, 1, wRec, 0.0) # negative synapse get pruned - # GeneralUtils.replaceElements!(mask_negative, 1, synapticActivityCounter, -0.1) - # # set pruned synapse to random wait time - # waittime = rand((1:1000), size(wRec)) .* mask_negative # synapse's random wait time to reconnect - # # synapseReconnectDelay counting mode when value is negative hence .* -1 - # synapseReconnectDelay .= (synapseReconnectDelay .* mask_positive) .+ (waittime .* -1) - - # # seperate active synapse out of inactive in this signal - # mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) - - # # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec - # avgActivity = sum(synapticActivityCounter) / length(synapticActivityCounter) - # lowerlimit = 0.1 * avgActivity - - # # +w, synapse with more than 10% of avg activity get increase weight by eta - # mask_more = (!isless).(synapticActivityCounter, lowerlimit) - # mask_2 = GeneralUtils.allTrue.(mask_activeSynapse, mask_more) - # mask_2 .*= 1 .+ eta # minor activity synapse weight will be reduced by eta - # wRec .*= mask_2 - - # # -w, synapse with less than 10% of avg activity get reduced weight by eta - # mask_less = GeneralUtils.isBetween.(synapticActivityCounter, 0.0, lowerlimit) # 1st criteria - - # mask_3 = GeneralUtils.allTrue.(mask_activeSynapse, mask_less) - # mask_3 .*= 1 .- eta # minor activity synapse weight will be reduced by eta - # wRec .*= mask_3 - - # -w all non-fire connection except mature connection - # mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) - # mask_notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 - # mask_1 = GeneralUtils.allTrue.(mask_inactiveSynapse, mask_notmature) - # mask_1 .*= 1 .- eta - # wRec .*= mask_1 - - # prune synapse - # mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) - # mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) - # wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned - # # all weak synapse activity are reset - # GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, -0.1) - # # set pruned synapse to random wait time - # r = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect - # # synapseReconnectDelay counting mode when value is negative hence .* -1 - # synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (r .* -1) - - # rewire synapse connection - # for i in 1:i3 # neuron-by-neuron - # if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight - # println("neuron $i die") - # neuronInactivityCounter[:,:,i] .= 0 # reset - # w = random_wRec(i1,i2,1,synapseConnectionNumber) - # wRec[:,:,i] .= w - - # a = similar(w) .= -0.1 # synapseConnectionNumber of this neuron - # mask = (!iszero).(w) - # GeneralUtils.replaceElements!(mask, 1, a, 0) - # synapseReconnectDelay[:,:,i] = a - # else #WORKING - # for i in eachindex(synapseReconnectDelay[:,:,i]) - # if i > 0 # mark timeStep available - # # get neuron pool at 10 timeStep earlier - # earlier = i - 10 > 0 ? i : i - 10 - # pool = sum(zitCumulative[:,:,earlier:i], dims=3) - # indices = findall(x -> x != 0, pool) - # pick = rand(indices) - # wRec[:,:,i][pick] = rand(0.01:0.01:0.5) - # synapticActivityCounter[:,:,i][pick] = 0 - # synapseReconnectDelay[:,:,i][pick] = -0.1 - # end - # end - # end - # end + error("DEBUG -> neuroplasticity") elseif progress == 0 # no progress, no weight update, only rewire # -w all non-fire connection except mature connection diff --git a/src/snnUtil.jl b/src/snnUtil.jl index adce33b..9217d5f 100644 --- a/src/snnUtil.jl +++ b/src/snnUtil.jl @@ -22,84 +22,37 @@ function refractoryStatus!(refractoryCounter, refractoryActive, refractoryInacti end end -# function frobenius_distance(A, B) -# # Check if the matrices have the same size -# if size(A) != size(B) -# error("The matrices must have the same size") +# function addNewSynapticConn!(mask::AbstractArray{<:Any}, markValue::Number, wRec::AbstractArray{<:Any}, +# counter::AbstractArray{<:Any}, n=0; +# rng::AbstractRNG=MersenneTwister(1234)) + +# # check if mask and wRec have the same size +# if size(mask) != size(wRec) +# error("mask and wRec must have the same size") # end -# # Initialize the distance to zero -# distance = 0.0 -# # Loop over the elements of the matrices and add the squared differences -# for i in 1:size(A, 1) -# for j in 1:size(A, 2) -# distance += (A[i, j] - B[i, j])^2 -# end +# # get the indices of elements in mask that equal markValue +# indices = findall(x -> x == markValue, mask) +# alreadySub = findall(x -> x != 0, wRec) # get already subscribe +# setdiff!(indices, alreadySub) # remove already sub conn from pool + +# remaining = 0 +# if n == 0 || n > length(indices) +# remaining = n - length(indices) +# n = length(indices) # end -# # Return the square root of the distance -# return sqrt(distance) -# end -function addNewSynapticConn!(mask::AbstractArray{<:Any}, markValue::Number, wRec::AbstractArray{<:Any}, - counter::AbstractArray{<:Any}, n=0; - rng::AbstractRNG=MersenneTwister(1234)) - # println("mask ", mask, size(mask)) - # println("") - # println("x ", x, size(x)) - # println("") - # println("wRec ", wRec, size(wRec)) - # println("") - # println("counter ", counter, size(counter)) - # println("") - # println("n ", n, size(n)) - # println("") - - # check if mask and wRec have the same size - if size(mask) != size(wRec) - error("mask and wRec must have the same size") - end - # get the indices of elements in mask that equal markValue - indices = findall(x -> x == markValue, mask) - alreadySub = findall(x -> x != 0, wRec) # get already subscribe - setdiff!(indices, alreadySub) # remove already sub conn from pool - - remaining = 0 - if n == 0 || n > length(indices) - remaining = n - length(indices) - n = length(indices) - end - - # shuffle the indices using the rng function - shuffle!(rng, indices) - # select the first n indices - n > length(indices) ? println(">>> ", total_x_tobeReplced) : nothing - selected = indices[1:n] - # replace the elements in wRec at the selected positions with a - for i in selected - wRec[i] = rand(0.01:0.01:0.1) - counter[i] = 0 # counting start from 0 - end - # error("DEBUG addNewSynapticConn!") - return remaining -end - -# function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractArray, -# synapticActivityCounter::AbstractArray, -# synapseReconnectDelay::AbstractArray) -# println("wRec 2 $(size(wRec)) ", wRec[:,:,1,1]) -# println("wRecChange ", wRecChange[:,:,1,1]) -# #WORKING look for flipped sign, it needs to get pruned -# wRec .= (exInType .* wRec) .+ wRecChange -# println("wRec 3 $(size(wRec)) ", wRec[:,:,1,1]) -# mask_negative = isless.(wRec, 0) -# mask_positive = (!isless).(wRec, 0) -# GeneralUtils.replaceElements!(mask_negative, 1, wRec, 0.0) # negative synapse get pruned -# println("wRec 4 $(size(wRec)) ", wRec[:,:,1,1]) -# GeneralUtils.replaceElements!(mask_negative, 1, synapticActivityCounter, -0.1) -# # set pruned synapse to random wait time -# waittime = rand((1:1000), size(wRec)) .* mask_negative # synapse's random wait time to reconnect -# # synapseReconnectDelay counting mode when value is negative hence .* -1 -# synapseReconnectDelay .= (synapseReconnectDelay .* mask_positive) .+ (waittime .* -1) -# error("DEBUG -> mergeLearnWeight!") +# # shuffle the indices using the rng function +# shuffle!(rng, indices) +# # select the first n indices +# n > length(indices) ? println(">>> ", total_x_tobeReplced) : nothing +# selected = indices[1:n] +# # replace the elements in wRec at the selected positions with a +# for i in selected +# wRec[i] = rand(0.01:0.01:0.1) +# counter[i] = 0 # counting start from 0 +# end +# # error("DEBUG addNewSynapticConn!") +# return remaining # end function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractArray, @@ -120,7 +73,7 @@ function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractAr # println("newsign ", newsign[:,:,1,1]) flipsign = (!isequal).(originalsign, newsign) # println("flipsign ", flipsign[:,:,1,1]) - nonflipsign = (isequal).(originalsign, newsign) + nonflipsign = isequal.(originalsign, newsign) wRec .= abs.(wRecSigned) # wRec store magnitude only, sign is at exInType # println("wRec 4 $(size(wRec)) ", wRec[:,:,1,1]) @@ -131,7 +84,7 @@ function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractAr waittime = rand((1:1000), size(wRec)) .* flipsign # synapse's random wait time to reconnect # synapseReconnectDelay counting mode when value is negative hence .* -1 synapseReconnectDelay .= (synapseReconnectDelay .* nonflipsign) .+ (waittime .* -1) - # println("synapseReconnectDelay ", synapseReconnectDelay[:,:,1,1]) #TODO check value + # println("synapseReconnectDelay ", synapseReconnectDelay[:,:,1,1]) # error("DEBUG -> mergeLearnWeight!") end @@ -158,9 +111,8 @@ function growRepeatedPath!(wRec, synapticActivityCounter, eta) #BUG wRec get al mask_less = GeneralUtils.isBetween.(synapticActivityCounter, 0, lowerlimit) # 1st criteria mask_3 = GeneralUtils.allTrue.(mask_activeSynapse, mask_less) mask_4 = mask_3 .* (1 .- eta) # minor activity synapse weight will be reduced by eta - # println("wRec 1 ", wRec[:,:,1,1]) - # println("mask_less ", mask_less[:,:,1,1]) - GeneralUtils.replaceElements!(mask_4, 0, 1) # replace 0 with 1 so mask * Wrec will not get 0 weight + # replace 0 with 1 so mask * wRec will not get 0 weight i.e. non-effected weight remain the same + GeneralUtils.replaceElements!(mask_4, 0, 1) # println("mask_4 ", mask_4[:,:,1,1]) wRec .*= mask_4 # println("wRec 2 ", wRec[:,:,1,1]) @@ -187,9 +139,10 @@ function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) # all weak synapse activity are reset GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, 0) # set pruned synapse to random wait time - r = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect + waittime = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect # synapseReconnectDelay counting mode when value is negative hence .* -1 - synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (r .* -1) + synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (waittime .* -1) + # error("DEBUG -> pruneSynapse!") end function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractArray, @@ -197,34 +150,36 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr synapseReconnectDelay::AbstractArray, zitCumulative::AbstractArray) _,_,i3,i4 = size(wRec) - for i in 1:i3 # neuron-by-neuron - if neuronInactivityCounter[1,1,i,i4][1] < -10000 # neuron die i.e. reset all weight - println("neuron $i die") - neuronInactivityCounter[:,:,i,i4] .= 0 # reset + for n in 1:i3 # neuron-by-neuron + if neuronInactivityCounter[1,1,n,i4][1] < -10000 # neuron die i.e. reset all weight + println("neuron $n die") + neuronInactivityCounter[:,:,n,i4] .= 0 # reset w = random_wRec(i1,i2,1,synapseConnectionNumber) - wRec[:,:,i,i4] .= w + wRec[:,:,n,i4] .= w a = similar(w) .= -0.1 # temp matrix use to put -0.1 into synapseReconnectDelay mask = (!iszero).(w) GeneralUtils.replaceElements!(mask, 1, a, 0) - synapseReconnectDelay[:,:,i,i4] = a + synapseReconnectDelay[:,:,n,i4] = a else - for i in eachindex(synapseReconnectDelay[:,:,i,i4]) - println("synapse 0 ", synapseReconnectDelay[:,:,i,i4]) - if i > 0 # mark timeStep available + for ind in eachindex(synapseReconnectDelay[:,:,n,i4]) + timemark = synapseReconnectDelay[:,:,n,i4][ind] + if timemark > 0 # mark timeStep available # get neuron pool at 10 timeStep earlier - earlier = i - 10 > 0 ? i - 10 : i - println("i $i earlier $earlier") - println("zit $(size(zitCumulative)) ") - pool = sum(zitCumulative[:,:,earlier:i], dims=3) - println("pool $(size(pool)) ", pool) - indices = findall(x -> x != 0, pool) - pick = rand(indices) - # println("wRec 1 ", wRec[:,:,i,i4]) - wRec[:,:,i,i4][pick] = rand(0.01:0.01:0.5) - # println("wRec 2 ", wRec[:,:,i,i4]) - synapticActivityCounter[:,:,i,i4][pick] = 0 - synapseReconnectDelay[:,:,i,i4][pick] = -0.1 + earlier = timemark - 10 > 0 ? timemark - 10 : timemark + pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) + println(">>> ", sum(pool)) + if sum(pool) != 0 + indices = findall(x -> x != 0, pool) + pick = rand(indices) + wRec[pick] = rand(0.01:0.01:0.05) + synapticActivityCounter[pick] = 0 + synapseReconnectDelay[pick] = -0.1 + else # if neurons not firing at all, try again next time + synapticActivityCounter[pick] = 0 + synapseReconnectDelay[:,:,n,i4] = rand(1:1000) + end + error("DEBUG -> rewireSynapse!") end end