module learn

export learn!, compute_paramsChange!

using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates
using GeneralUtils
using ..type, ..snnUtil

#------------------------------------------------------------------------------------------------100

"""
    compute_paramsChange!(kfn::kfn_1, modelError::CuArray, outputError::CuArray, label)

Accumulate this step's parameter changes for every population of `kfn`:
`kfn.lif_wRecChange`, `kfn.alif_wRecChange` and `kfn.on_wOutChange` are incremented in
place, and each population's `epsilonRec` (and ALIF `epsilonRecA`) is reset to zero.

# Arguments
- `modelError`: error driving the recurrent (LIF/ALIF) populations.
- `outputError`: per-output-neuron error driving the output (`on`) population.
- `label`: index of the correct answer, 0-based (`label + 1` selects the output neuron).

Presumably this is the e-prop rule of Bellec et al. (2020), since the ALIF trace cites its
eq. 25 — confirm.
"""
function compute_paramsChange!(kfn::kfn_1, modelError::CuArray, outputError::CuArray, label)
    lifComputeParamsChange!(kfn.timeStep,
                            kfn.lif_phi,
                            kfn.lif_epsilonRec,
                            kfn.lif_eta,
                            kfn.lif_eRec,
                            kfn.lif_wRec,
                            kfn.lif_exInType,
                            kfn.lif_wRecChange,
                            kfn.on_wOut,
                            kfn.lif_firingCounter,
                            kfn.lif_firingTargetFrequency,
                            kfn.lif_arrayProjection4d,
                            kfn.lif_error,
                            modelError,
                            outputError,
                            kfn.inputSize,
                            kfn.bk,
                            label,
                            )

    alifComputeParamsChange!(kfn.timeStep,
                             kfn.alif_phi,
                             kfn.alif_epsilonRec,
                             kfn.alif_eta,
                             kfn.alif_eRec,
                             kfn.alif_wRec,
                             kfn.alif_exInType,
                             kfn.alif_wRecChange,
                             kfn.on_wOut,
                             kfn.alif_firingCounter,
                             kfn.alif_firingTargetFrequency,
                             kfn.alif_arrayProjection4d,
                             kfn.alif_error,
                             modelError,
                             outputError,
                             kfn.inputSize,
                             kfn.bk,
                             label,
                             kfn.alif_epsilonRecA,
                             kfn.alif_beta,
                             )

    onComputeParamsChange!(kfn.on_phi,
                           kfn.on_epsilonRec,
                           kfn.on_eta,
                           kfn.on_eRec,
                           kfn.on_wOutChange,
                           kfn.on_arrayProjection4d,
                           kfn.on_error,
                           outputError,
                           )
end

# Learning signal for one recurrent population, shared by the LIF and ALIF methods.
#
# wRec's 2D (row, col) grid contains input, lif and alif neurons. The columns belonging to
# this population start right after the first prod(inputSize) linear positions and span
# size(wRec, 3) positions; startCol/stopCol are the column (2nd index) of those linear
# positions.
# RSNN neurons with a direct connection to the output neuron of the correct answer
# (label + 1, since Julia is 1-based) use that wOut weight as Bjk; the rest use their
# random Bjk from `bk`. The result, Bjk .* modelError, is reshaped to (1, 1, :, 1) so it
# broadcasts along the 3rd (neuron) axis.
#
# NOTE(review): the ALIF caller uses the same start offset as LIF (directly after the
# input neurons). If the grid lays out input, then lif, then alif neurons, the ALIF
# columns would need an extra offset of the LIF neuron count — confirm.
function _learningSignal(wRec, wOut, bk, modelError, inputSize, label)
    startIndex = prod(inputSize) + 1
    stopIndex = startIndex + size(wRec, 3) - 1
    startCol = CartesianIndices(wRec)[startIndex][2]
    stopCol = CartesianIndices(wRec)[stopIndex][2]

    onW = @view(wOut[:, startCol:stopCol, sum(label + 1), 1]) # label+1: 1-based index
    randomB = @view(bk[:, startCol:stopCol, 1])
    # direct wOut weight where it is non-zero, random Bjk everywhere else
    feedback = onW .+ iszero.(onW) .* randomB
    return reshape(feedback .* modelError, (1, 1, :, 1))
end

# Frequency regulator shared by LIF and ALIF: nudges wRecChange toward the target firing
# rate, scaled by eta and the eligibility trace eRec.
function _frequencyRegulate!(wRecChange, firingTargetFrequency, firingCounter, timeStep,
                             eta, eRec)
    wRecChange .+= 0.001 .*
                   ((firingTargetFrequency .- (firingCounter ./ timeStep)) ./ timeStep) .*
                   eta .* eRec
    return wRecChange
end

"""
    lifComputeParamsChange!(timeStep, phi, epsilonRec, eta, eRec, wRec, exInType,
                            wRecChange, wOut, firingCounter, firingTargetFrequency,
                            arrayProjection4d, nError, modelError, outputError,
                            inputSize, bk, label)

GPU update of the LIF population's accumulated weight change:

- `eRec .= phi .* epsilonRec` (eligibility trace);
- `wRecChange .+= eta .* L .* eRec`, where `L` is the learning signal from
  `_learningSignal`;
- the firing-frequency regulator term;
- then `epsilonRec` is reset to 0.

`exInType`, `arrayProjection4d`, `nError` and `outputError` are accepted for call-site
compatibility but are not read. The learning signal is kept local, so `nError` is not
written.
"""
function lifComputeParamsChange!(timeStep::CuArray,
                                 phi::CuArray,
                                 epsilonRec::CuArray,
                                 eta::CuArray,
                                 eRec::CuArray,
                                 wRec::CuArray,
                                 exInType::CuArray,
                                 wRecChange::CuArray,
                                 wOut::CuArray,
                                 firingCounter::CuArray,
                                 firingTargetFrequency::CuArray,
                                 arrayProjection4d::CuArray,
                                 nError::CuArray,
                                 modelError::CuArray,
                                 outputError::CuArray,
                                 inputSize::CuArray,
                                 bk::CuArray,
                                 label,
                                 )
    eRec .= phi .* epsilonRec

    learningSignal = _learningSignal(wRec, wOut, bk, modelError, inputSize, label)
    wRecChange .+= (eta .* learningSignal .* eRec)

    _frequencyRegulate!(wRecChange, firingTargetFrequency, firingCounter, timeStep, eta, eRec)

    # reset epsilonRec
    epsilonRec .= 0
end

"""
    alifComputeParamsChange!(timeStep, phi, epsilonRec, eta, eRec, wRec, exInType,
                             wRecChange, wOut, firingCounter, firingTargetFrequency,
                             arrayProjection4d, nError, modelError, outputError,
                             inputSize, bk, label, epsilonRecA, beta)

GPU update of the ALIF population's accumulated weight change. It is the same as the LIF
method except that the eligibility trace includes the adaptive-threshold term:
`eRec .= phi .* (epsilonRec .- beta .* epsilonRecA)` (eq. 25).

Both `epsilonRec` and `epsilonRecA` are reset to 0 afterwards. As in the LIF method,
`nError` is not written.
"""
function alifComputeParamsChange!(timeStep::CuArray,
                                  phi::CuArray,
                                  epsilonRec::CuArray,
                                  eta::CuArray,
                                  eRec::CuArray,
                                  wRec::CuArray,
                                  exInType::CuArray,
                                  wRecChange::CuArray,
                                  wOut::CuArray,
                                  firingCounter::CuArray,
                                  firingTargetFrequency::CuArray,
                                  arrayProjection4d::CuArray,
                                  nError::CuArray,
                                  modelError::CuArray,
                                  outputError::CuArray,
                                  inputSize::CuArray,
                                  bk::CuArray,
                                  label,
                                  epsilonRecA::CuArray,
                                  beta::CuArray,
                                  )
    eRec .= phi .* (epsilonRec .- (beta .* epsilonRecA)) # use eq. 25

    learningSignal = _learningSignal(wRec, wOut, bk, modelError, inputSize, label)
    wRecChange .+= (eta .* learningSignal .* eRec)

    _frequencyRegulate!(wRecChange, firingTargetFrequency, firingCounter, timeStep, eta, eRec)

    # reset epsilonRec
    epsilonRec .= 0
    epsilonRecA .= 0
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, eRec, wOutChange, arrayProjection4d,
                           nError, outputError)

GPU update of the output neurons' accumulated weight change. `outputError` is each output
neuron's own error. It is reshaped to `(1, 1, :, size(outputError, 2))`, masked by
`arrayProjection4d` into `nError` (written in place), and then
`wOutChange .-= eta .* nError .* eRec`. `epsilonRec` is reset to 0.
"""
function onComputeParamsChange!(phi::CuArray,
                                epsilonRec::CuArray,
                                eta::CuArray,
                                eRec::CuArray,
                                wOutChange::CuArray,
                                arrayProjection4d::CuArray,
                                nError::CuArray,
                                outputError::CuArray # output neuron's error
                                )
    eRec .= phi .* epsilonRec
    nError .= reshape(outputError, (1, 1, :, size(outputError, 2))) .* arrayProjection4d
    wOutChange .-= eta .* nError .* eRec

    # reset epsilonRec
    epsilonRec .= 0
end

"""
    lifComputeParamsChange!(phi, epsilonRec, eta, wRec, wRecChange, wOut, modelError)

Loop-based (non-CuArray) LIF update, iterating over neuron axis `i` (dim 3) and batch
axis `j` (dim 4) of `epsilonRec`.

For each `(i, j)`:

    wRecChange[:, :, i, j] += -eta₁ * (phi₁ .* epsilonRec[:, :, i, j]) * (modelError₁ * Bₖⱼ)

Here `eta₁`, `phi₁` and `modelError₁` are the first elements of the corresponding slices.
`Bₖⱼ` is element `i` of the neuron's summed `wOut` weights (`sum(wOut, dims=3)`).
`wRec` is not read.
"""
function lifComputeParamsChange!(phi::AbstractArray,
                                 epsilonRec::AbstractArray,
                                 eta::AbstractArray,
                                 wRec::AbstractArray,
                                 wRecChange::AbstractArray,
                                 wOut::AbstractArray,
                                 modelError::AbstractArray)
    d1, _, d3, d4 = size(epsilonRec)
    # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3 # neurons axis of every batch
        negEta = -1 * view(eta, :, :, i, j)[1]
        eRec = view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)
        # learning signal: summed model error (dopamine-like) times this neuron's Bₖⱼ
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]
        view(wRecChange, :, :, i, j) .+= negEta .* (eRec .* learningSignal)
    end
end

"""
    alifComputeParamsChange!(phi, epsilonRec, epsilonRecA, eta, wRec, wRecChange, beta,
                             wOut, modelError)

Loop-based (non-CuArray) ALIF update. It is the same as the loop-based LIF method, but
the eligibility trace is `phi₁ .* epsilonRec .- (phi₁ * beta₁) .* epsilonRecA`, i.e.
`phi * (epsilonRec - beta * epsilonRecA)` (eq. 25), matching the CuArray method.
"""
function alifComputeParamsChange!(phi::AbstractArray,
                                  epsilonRec::AbstractArray,
                                  epsilonRecA::AbstractArray,
                                  eta::AbstractArray,
                                  wRec::AbstractArray,
                                  wRecChange::AbstractArray,
                                  beta::AbstractArray,
                                  wOut::AbstractArray,
                                  modelError::AbstractArray)
    d1, _, d3, d4 = size(epsilonRec)
    # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3 # neurons axis of every batch
        negEta = -1 * view(eta, :, :, i, j)[1]
        phiNeuron = view(phi, :, :, i, j)[1]
        # eRec_v minus eRec_a (eq. 25); previously added, inconsistent with the GPU path
        eRec = (phiNeuron .* view(epsilonRec, :, :, i, j)) .-
               ((phiNeuron * view(beta, :, :, i, j)[1]) .* view(epsilonRecA, :, :, i, j))
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]
        view(wRecChange, :, :, i, j) .+= negEta .* (eRec .* learningSignal)
    end
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, wOutChange, outputError)

Loop-based (non-CuArray) output-neuron update. For each neuron `i` and batch `j`:

    wOutChange[:, :, i, j] += -eta₁ * (phi₁ .* epsilonRec[:, :, i, j]) * outputError[i, j]

The output neuron receives the error of its own answer (its answer minus the correct
answer).
"""
function onComputeParamsChange!(phi::AbstractArray,
                                epsilonRec::AbstractArray,
                                eta::AbstractArray,
                                wOutChange::AbstractArray,
                                outputError::AbstractArray)
    _, _, d3, d4 = size(epsilonRec)
    for j in 1:d4, i in 1:d3 # neurons axis of every batch
        negEta = -1 * view(eta, :, :, i, j)[1]
        eRec = view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)
        view(wOutChange, :, :, i, j) .+= negEta .* (eRec .* view(outputError, :, j)[i])
    end
end

"""
    learn!(kfn::kfn_1, progress, device=cpu)

Apply one learning session's accumulated changes to `kfn`:

- LIF and ALIF recurrent weights go through `lifLearn` / `alifLearn`. Plasticity runs on
  the CPU and the results are moved with `device`. The new `wRec`,
  `neuronInactivityCounter` and `synapseReconnectDelay` arrays are stored back into
  `kfn`.
- Output weights are updated in place by `onLearn!`.
- If `kfn.learningStage == [3]`, it is reset to `[0]`.

`progress` is forwarded to `neuroplasticity`, which interprets it. NOTE(review): with the
default `device=cpu`, the stored arrays end up on the host; GPU callers presumably pass
`gpu` — confirm.
"""
function learn!(kfn::kfn_1, progress, device=cpu)
    # lif learn
    kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapseReconnectDelay =
        lifLearn(kfn.lif_wRec,
                 kfn.lif_wRecChange,
                 kfn.lif_exInType,
                 kfn.lif_arrayProjection4d,
                 kfn.lif_neuronInactivityCounter,
                 kfn.lif_synapseReconnectDelay,
                 kfn.lif_synapseConnectionNumber,
                 kfn.lif_synapticActivityCounter,
                 kfn.lif_eta,
                 kfn.lif_vt,
                 kfn.zitCumulative,
                 progress,
                 device)

    # alif learn
    kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapseReconnectDelay =
        alifLearn(kfn.alif_wRec,
                  kfn.alif_wRecChange,
                  kfn.alif_exInType,
                  kfn.alif_arrayProjection4d,
                  kfn.alif_neuronInactivityCounter,
                  kfn.alif_synapseReconnectDelay,
                  kfn.alif_synapseConnectionNumber,
                  kfn.alif_synapticActivityCounter,
                  kfn.alif_eta,
                  kfn.alif_vt,
                  kfn.zitCumulative,
                  progress,
                  device)

    # on learn
    onLearn!(kfn.on_wOut, kfn.on_wOutChange, kfn.on_arrayProjection4d)

    # wrap up learning session
    if kfn.learningStage == [3]
        kfn.learningStage = [0]
    end
end

"""
    lifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
             synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter,
             eta, vt, zitCumulative, progress, device)

Bring one recurrent population's arrays to the CPU and run `neuroplasticity` on them.
Then mask the resulting `wRec`, `neuronInactivityCounter` and `synapseReconnectDelay`
with `arrayProjection4d` and transfer them with `device`.

Returns `(wRec, neuronInactivityCounter, synapseReconnectDelay)`.

NOTE(review): `neuroplasticity` also mutates the CPU versions of `wRecChange` and
`synapticActivityCounter`, but those are not copied back to the caller's (device)
arrays — confirm that is intended.
"""
function lifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
                  synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter,
                  eta, vt, zitCumulative, progress, device)
    arrayProjection4d_cpu = arrayProjection4d |> cpu

    # neuroplasticity, work on CPU side
    wRec_cpu, neuronInactivityCounter_cpu, synapseReconnectDelay_cpu =
        neuroplasticity(synapseConnectionNumber,
                        zitCumulative |> cpu,
                        wRec |> cpu,
                        exInType |> cpu,
                        wRecChange |> cpu,
                        vt,
                        eta |> cpu,
                        neuronInactivityCounter |> cpu,
                        synapseReconnectDelay |> cpu,
                        synapticActivityCounter |> cpu,
                        progress)

    # mask with arrayProjection4d, then transfer back to `device`
    wRec = (wRec_cpu .* arrayProjection4d_cpu) |> device
    neuronInactivityCounter = (neuronInactivityCounter_cpu .* arrayProjection4d_cpu) |> device
    synapseReconnectDelay = (synapseReconnectDelay_cpu .* arrayProjection4d_cpu) |> device

    return wRec, neuronInactivityCounter, synapseReconnectDelay
end

"""
    alifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
              synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter,
              eta, vt, zitCumulative, progress, device)

ALIF counterpart of `lifLearn`. It uses the same CPU-side `neuroplasticity` pipeline and
has the same return value.

The previous implementation called an obsolete 5-argument `neuroplasticity` method
(a `MethodError`) and indexed the device array inside a CPU loop.
"""
function alifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
                   synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter,
                   eta, vt, zitCumulative, progress, device)
    return lifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
                    synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter,
                    eta, vt, zitCumulative, progress, device)
end

"""
    onLearn!(wOut, wOutChange, arrayProjection4d)

Add the batch-averaged weight change (`sum(wOutChange, dims=4) / size(wOut, 4)`, masked
by `arrayProjection4d`) to `wOut` in place. Then decay `wOut` by 0.1% (c_decay, to help
convergence).
"""
function onLearn!(wOut, wOutChange, arrayProjection4d)
    # merge learning weight with average learning weight
    wOut .+= (sum(wOutChange, dims=4) ./ (size(wOut, 4))) .* arrayProjection4d

    # adaptive wOut to help convergence using c_decay
    wOut .-= 0.001 .* wOut
end

# Returns false if any argument == false, true otherwise.
# NOTE(review): this adds a method to a function owned by GeneralUtils for argument types
# this module does not own (type piracy). No live code in this module calls it; consider
# moving it into GeneralUtils.
GeneralUtils.allTrue(args...) = false ∈ [args...] ? false : true

# Shared tail of both learning branches of `neuroplasticity`.
function _weakenPruneRewire!(wRec, eta, neuronInactivityCounter, synapseReconnectDelay,
                             synapticActivityCounter, zitCumulative)
    # -w all non-fire connection except mature connection
    weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
    # prune weak synapse
    pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
    # rewire synapse connection
    rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
                   synapseReconnectDelay, zitCumulative)
    return nothing
end

"""
    neuroplasticity(synapseConnectionNumber, zitCumulative, wRec, exInType, wRecChange,
                    vt, eta, neuronInactivityCounter, synapseReconnectDelay,
                    synapticActivityCounter, progress)

Apply weight and structural plasticity in place to (CPU) arrays of one recurrent
population. Returns `(wRec, neuronInactivityCounter, synapseReconnectDelay)`.

Behavior by `progress`:

- `2`: no need to learn. `wRecChange` is zeroed and nothing else changes.
- `0`: no progress, so no weight update. Immature non-firing synapses are weakened, weak
  synapses are pruned and connections are rewired.
- any other value: progress increased.
  - `wRecChange` is zeroed where `wRec == 0` (synapses waiting to reconnect).
  - It is merged into `wRec`; negative results get pruned.
  - Repeated paths are strengthened.
  - Then the same weaken / prune / rewire steps as above run.

`synapseConnectionNumber` and `vt` are accepted for interface stability but are not read.
The helpers (`mergeLearnWeight!`, `growRepeatedPath!`, `weakenNotMatureSynapse!`,
`pruneSynapse!`, `rewireSynapse!`) are not defined in this file — presumably provided by
`snnUtil`; confirm.
"""
function neuroplasticity(synapseConnectionNumber,
                         zitCumulative, # (row, col)
                         wRec, # (row, col, n)
                         exInType,
                         wRecChange,
                         vt,
                         eta,
                         neuronInactivityCounter,
                         synapseReconnectDelay,
                         synapticActivityCounter,
                         progress,) # (row, col, n)
    if progress == 2 # no need to learn
        # TODO I may need to do something with neuronInactivityCounter and other variables
        wRecChange .= 0
    elseif progress == 0 # no progress, no weight update, only rewire
        _weakenPruneRewire!(wRec, eta, neuronInactivityCounter, synapseReconnectDelay,
                            synapticActivityCounter, zitCumulative)
    else # progress increase
        # ready to reconnect synapse must not have wRecChange
        wRecChange .*= (!isequal).(wRec, 0)
        # merge learning weight, all resulting negative wRec will get pruned
        mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter,
                          synapseReconnectDelay)
        # adjust wRec based on repetition (90% +w, 10% -w)
        growRepeatedPath!(wRec, synapticActivityCounter, eta)
        _weakenPruneRewire!(wRec, eta, neuronInactivityCounter, synapseReconnectDelay,
                            synapticActivityCounter, zitCumulative)
    end

    return wRec, neuronInactivityCounter, synapseReconnectDelay
end

"""
    learningLiquidity(x)

Map `x` linearly from 1.0 at `x = -10000` down to 0.0 at `x = +10000`
(`y = -5e-05x + 0.5`). Returns 1.0 below -10000 and 0.0 above +10000.
"""
function learningLiquidity(x)
    if x > 10000
        y = 0.0
    elseif x < -10000
        y = 1.0
    else
        y = -5e-05x + 0.5 # range -10000 to +10000
    end
    return y
end

end # module