module snnUtil

export refractoryStatus!, addNewSynapticConn!, mergeLearnWeight!, growRepeatedPath!,
       weakenNotMatureSynapse!, pruneSynapse!, rewireSynapse!

using Random, GeneralUtils

#------------------------------------------------------------------------------------------------100

"""
    refractoryStatus!(refractoryCounter, refractoryActive, refractoryInactive)

Write the refractory state of every neuron into two complementary flag arrays.

For each `(i, j)` along dimensions 3 and 4, a neuron is inactive (refractory) when
`refractoryCounter[1, 1, i, j] > 0`. In that case `refractoryActive[1, 1, i, j]` is set
to 0 and `refractoryInactive[1, 1, i, j]` to 1. Otherwise the flags are set the other
way round. Only index `[1, 1, :, :]` of each array is read or written. Returns `nothing`.
"""
function refractoryStatus!(refractoryCounter, refractoryActive, refractoryInactive)
    for j in axes(refractoryCounter, 4), i in axes(refractoryCounter, 3)
        inactive = refractoryCounter[1, 1, i, j] > 0
        # write through 0-dim views (as before) rather than scalar setindex!
        view(refractoryActive, 1, 1, i, j) .= inactive ? 0 : 1
        view(refractoryInactive, 1, 1, i, j) .= inactive ? 1 : 0
    end
    return nothing
end

"""
    addNewSynapticConn!(mask, markValue, wRec, counter, n=0; rng=MersenneTwister(1234))

Connect up to `n` new synapses at random positions of `wRec`.

Candidate positions are those where `mask == markValue` and `wRec` is zero (not yet
connected). `n == 0` means "connect every candidate". Each selected position gets a
random weight from `0.01:0.01:0.1`, drawn from `rng`, and its `counter` entry is reset
to 0. The candidates are shuffled with `rng` before selection.

Returns the number of requested connections that could not be made because too few
candidates were available. The result is never negative and is 0 when `n == 0`.

Throws an error if `mask` and `wRec` differ in size.
"""
function addNewSynapticConn!(mask::AbstractArray{<:Any}, markValue::Number,
                             wRec::AbstractArray{<:Any}, counter::AbstractArray{<:Any},
                             n=0; rng::AbstractRNG=MersenneTwister(1234))
    if size(mask) != size(wRec)
        error("mask and wRec must have the same size")
    end

    # candidate positions: marked in `mask` and not already connected in `wRec`
    candidates = findall(x -> x == markValue, mask)
    setdiff!(candidates, findall(x -> x != 0, wRec))
    available = length(candidates)

    # n == 0 means "all candidates"; any shortfall is reported back to the caller.
    # Clamped at 0 so that n == 0 does not produce a negative count.
    remaining = 0
    if n == 0 || n > available
        remaining = max(n - available, 0)
        n = available
    end

    shuffle!(rng, candidates)
    for idx in candidates[1:n]
        wRec[idx] = rand(rng, 0.01:0.01:0.1)  # use the caller's RNG for reproducibility
        counter[idx] = 0                      # activity counting starts from 0
    end
    return remaining
end

"""
    mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter,
                      synapseReconnectDelay)

Apply a learned weight change and prune every synapse whose sign changes.

`wRec` holds weight magnitudes and `exInType` their signs, so the signed weight is
`exInType .* wRec`. After adding `wRecChange`, `wRec` is overwritten with the new
magnitudes.

A synapse whose sign differs before and after the update (including a change to or from
zero) is pruned. Its weight and activity counter are zeroed via
`GeneralUtils.replaceElements!`. Its `synapseReconnectDelay` entry is set to a random
value in `-1000:-1`, because negative values mean the reconnect countdown is running.
All other delay entries are left unchanged.

Returns `synapseReconnectDelay`.
"""
function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractArray,
                           synapticActivityCounter::AbstractArray,
                           synapseReconnectDelay::AbstractArray)
    wRecSigned = exInType .* wRec
    originalsign = sign.(wRecSigned)
    wRecSigned .+= wRecChange
    newsign = sign.(wRecSigned)

    # Compare with `!=`, not `isequal`. An unconnected inhibitory synapse has signed
    # weight -0.0 (-1 * 0.0), and isequal(-0.0, 0.0) is false, which would flag it as
    # flipped. With `!=`, -0.0 == 0.0, so no manual -0.0 cleanup is needed.
    flipsign = originalsign .!= newsign
    nonflipsign = .!flipsign

    wRec .= abs.(wRecSigned)  # wRec stores magnitude only; the sign lives in exInType
    GeneralUtils.replaceElements!(flipsign, 1, wRec, 0.0)  # prune sign-flipped synapses
    GeneralUtils.replaceElements!(flipsign, 1, synapticActivityCounter, 0)

    # Pruned synapses wait a random number of steps before they may reconnect.
    # A negative delay value means the countdown is running, hence the subtraction.
    waittime = rand(1:1000, size(wRec)) .* flipsign
    synapseReconnectDelay .= (synapseReconnectDelay .* nonflipsign) .- waittime
    return synapseReconnectDelay
end

"""
    growRepeatedPath!(wRec, synapticActivityCounter, eta)

Strengthen frequently used synapses and weaken rarely used ones, in place.

Let `lowerlimit = 0.1 * (sum(activity) / length(activity))`, i.e. 10% of the average
activity. Among active synapses (nonzero activity counter):
- those with activity `>= lowerlimit` are scaled by `1 + eta`;
- the remaining active ones for which `GeneralUtils.isBetween(activity, 0, lowerlimit)`
  holds are scaled by `1 - eta`.

All other synapses keep their weight. An empty activity array leaves `wRec` untouched.
Returns `wRec`.
"""
function growRepeatedPath!(wRec, synapticActivityCounter, eta)
    isempty(synapticActivityCounter) && return wRec  # nothing to average over

    active = (!iszero).(synapticActivityCounter)
    avgActivity = sum(synapticActivityCounter) / length(synapticActivityCounter)
    lowerlimit = 0.1 * avgActivity

    frequent = active .& .!isless.(synapticActivityCounter, lowerlimit)
    # `.!frequent` keeps the two groups disjoint so no synapse is scaled twice
    rare = active .& .!frequent .&
           GeneralUtils.isBetween.(synapticActivityCounter, 0, lowerlimit)

    # Factor is 1 + eta for `frequent`, 1 - eta for `rare`, and exactly 1 elsewhere.
    # Multiplying by the raw Bool masks instead would zero every synapse outside them.
    wRec .*= 1 .+ eta .* (frequent .- rare)
    return wRec
end

"""
    weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)

Scale by `1 - eta` every synapse that is inactive (activity counter equal to 0) and not
yet mature (`GeneralUtils.isBetween(w, 0.0, 0.1)`). All other weights are unchanged.
Returns `wRec`.
"""
function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
    inactive = isequal.(synapticActivityCounter, 0)
    notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1)  # immature: weight below 0.1
    # factor is 1 outside the mask so that every other synapse keeps its weight
    wRec .*= 1 .- eta .* (inactive .& notmature)
    return wRec
end

"""
    pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)

Remove weak synapses, i.e. those with `GeneralUtils.isBetween(w, 0.0, 0.01)`.

Pruned synapses get weight 0 and activity counter 0. Their `synapseReconnectDelay`
entry is set to a random value in `-1000:-1`, since negative values mean the reconnect
countdown is running. Other delay entries are unchanged. Returns `synapseReconnectDelay`.
"""
function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
    # NOTE(review): if isBetween includes its lower bound, already-pruned synapses
    # (weight 0.0) are re-marked on every call and their countdown restarts. Confirm.
    mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01)
    mask_notweak = .!mask_weak

    wRec .*= mask_notweak  # weak synapses are pruned to weight 0.0
    GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, 0)

    # random wait before reconnecting; negative value = countdown running
    r = rand(1:1000, size(wRec)) .* mask_weak
    synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .- r
    return synapseReconnectDelay
end

"""
    rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
                   synapseReconnectDelay, zitCumulative)

Reset dead neurons and reconnect synapses that have become available again.

Only index `size(wRec, 4)` along dimension 4 is processed. For each neuron `i` (dimension
3 of `wRec`):

- If `neuronInactivityCounter[1, 1, i, size(wRec, 4)] < -10000`, the neuron is treated as
  dead. Its inactivity counter is reset to 0 and its incoming weights are re-drawn with
  `random_wRec`. Its reconnect delays are filled with -0.1, with connected (nonzero-weight)
  entries replaced via `GeneralUtils.replaceElements!`.
- Otherwise, each synapse whose reconnect-delay value `t` is positive triggers a rewire.
  A position where `zitCumulative` summed over time steps `earlier:t` is nonzero is chosen
  at random, with `earlier = t - 10` when positive and `t` otherwise. That position gets a
  weight from `0.01:0.01:0.5`, activity counter 0 and delay -0.1.

Returns `nothing`.
"""
function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractArray,
                        synapticActivityCounter::AbstractArray,
                        synapseReconnectDelay::AbstractArray, zitCumulative::AbstractArray)
    i1, i2, i3, i4 = size(wRec)
    for i in 1:i3  # neuron-by-neuron
        if neuronInactivityCounter[1, 1, i, i4] < -10000  # neuron dies, reset all weights
            println("neuron $i die")
            neuronInactivityCounter[:, :, i, i4] .= 0
            # NOTE(review): `random_wRec` and `synapseConnectionNumber` are not defined in
            # this module, so this branch throws UndefVarError as written. TODO supply them.
            w = random_wRec(i1, i2, 1, synapseConnectionNumber)
            wRec[:, :, i, i4] .= w
            a = similar(w) .= -0.1  # temp matrix used to put -0.1 into synapseReconnectDelay
            mask = (!iszero).(w)
            GeneralUtils.replaceElements!(mask, 1, a, 0)
            synapseReconnectDelay[:, :, i, i4] = a
        else
            # views, so the writes below reach the parent arrays (slices would be copies)
            wNeuron = view(wRec, :, :, i, i4)
            counterNeuron = view(synapticActivityCounter, :, :, i, i4)
            delayNeuron = view(synapseReconnectDelay, :, :, i, i4)
            # loop variable is `j` so the neuron index `i` is not shadowed
            for j in eachindex(delayNeuron)
                t = delayNeuron[j]
                # assumes a positive delay value is the (integral) time step at which the
                # synapse became available. TODO confirm against the caller.
                t > 0 || continue
                step = Int(t)
                earlier = step - 10 > 0 ? step - 10 : step
                # positions that fired at least once in the time window
                pool = dropdims(sum(view(zitCumulative, :, :, earlier:step); dims=3); dims=3)
                candidates = findall(!iszero, pool)
                isempty(candidates) && continue  # nothing fired; rand([]) would throw
                pick = rand(candidates)
                wNeuron[pick] = rand(0.01:0.01:0.5)
                counterNeuron[pick] = 0
                delayNeuron[pick] = -0.1
            end
        end
    end
    return nothing
end

end # module