From 1cc135c911fec816bff90dbdaf4ad8309508e53f Mon Sep 17 00:00:00 2001 From: ton Date: Fri, 15 Sep 2023 21:11:43 +0700 Subject: [PATCH] dev --- src/forward.jl | 40 ++- src/learn copy.jl | 857 ++++++++++++++++++++++++++++++++++++++++++++++ src/learn.jl | 58 +++- src/type.jl | 26 +- 4 files changed, 937 insertions(+), 44 deletions(-) create mode 100644 src/learn copy.jl diff --git a/src/forward.jl b/src/forward.jl index 8e85723..c38301c 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -313,20 +313,21 @@ function lifForward( zit, (zit[i1,i2,i3,i4] * !iszero(wRec[i1,i2,i3,i4])) # !iszero indicates synaptic subscription - synapticActivityCounter[i1,i2,i3,i4] = zit[i1,i2,i3,i4] * !iszero(wRec[i1,i2,i3,i4]) - - if !iszero(wRec[i1,i2,i3,i4]) # check if this is wRec subscription - synapseReconnectDelay[i1,i2,i3,i4] -= 1 - if synapseReconnectDelay[i1,i2,i3,i4] == 0 - # mark timestep - synapseReconnectDelay[i1,i2,i3,i4] = sum(timeStep) - wRec[i1,i2,i3,i4] = -1.0 # mark for reconnect - end - end + synapticActivityCounter[i1,i2,i3,i4] += zit[i1,i2,i3,i4] * !iszero(wRec[i1,i2,i3,i4]) # voltage regulator wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) * zit[i1,i2,i3,i4] + + if !iszero(wRec[i1,i2,i3,i4]) && # check if this is wRec subscription + synapseReconnectDelay[i1,i2,i3,i4] != 0 + + synapseReconnectDelay[i1,i2,i3,i4] -= 1 + if synapseReconnectDelay[i1,i2,i3,i4] == 0 + # mark timestep + synapseReconnectDelay[i1,i2,i3,i4] = sum(timeStep) + end + end end end return nothing @@ -521,18 +522,21 @@ function alifForward( zit, (phi[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4])) + (zit[i1,i2,i3,i4] * !iszero(wRec[i1,i2,i3,i4])) - synapticActivityCounter[i1,i2,i3,i4] = zit[i1,i2,i3,i4] * !iszero(wRec[i1,i2,i3,i4]) + synapticActivityCounter[i1,i2,i3,i4] += zit[i1,i2,i3,i4] * !iszero(wRec[i1,i2,i3,i4]) - if !iszero(wRec[i1,i2,i3,i4]) # check if this is wRec subscription - synapseReconnectDelay[i1,i2,i3,i4] -= 1 - if synapseReconnectDelay[i1,i2,i3,i4] == 0 - synapseReconnectDelay[i1,i2,i3,i4] = sum(timeStep) - wRec[i1,i2,i3,i4] = -1.0 # mark for reconnect - end - end # voltage regulator wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - avth[i1,i2,i3,i4]) * zit[i1,i2,i3,i4] + + if !iszero(wRec[i1,i2,i3,i4]) && # check if this is wRec subscription + synapseReconnectDelay[i1,i2,i3,i4] != 0 + + synapseReconnectDelay[i1,i2,i3,i4] -= 1 + if synapseReconnectDelay[i1,i2,i3,i4] == 0 + # mark timestep + synapseReconnectDelay[i1,i2,i3,i4] = sum(timeStep) + end + end end end return nothing diff --git a/src/learn copy.jl b/src/learn copy.jl new file mode 100644 index 0000000..0fa24e5 --- /dev/null +++ b/src/learn copy.jl @@ -0,0 +1,857 @@ +module learn + +export learn!, compute_paramsChange! + +using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates +using GeneralUtils +using ..type, ..snnUtil + +#------------------------------------------------------------------------------------------------100 + +function compute_paramsChange!(kfn::kfn_1, modelError::CuArray, outputError::CuArray, label) + + lifComputeParamsChange!(kfn.timeStep, + kfn.lif_phi, + kfn.lif_epsilonRec, + kfn.lif_eta, + kfn.lif_eRec, + kfn.lif_wRec, + kfn.lif_exInType, + kfn.lif_wRecChange, + kfn.on_wOut, + kfn.lif_firingCounter, + kfn.lif_firingTargetFrequency, + kfn.lif_arrayProjection4d, + kfn.lif_error, + modelError, + outputError, + kfn.inputSize, + kfn.bk, + label, + ) + + alifComputeParamsChange!(kfn.timeStep, + kfn.alif_phi, + kfn.alif_epsilonRec, + kfn.alif_eta, + kfn.alif_eRec, + kfn.alif_wRec, + kfn.alif_exInType, + kfn.alif_wRecChange, + kfn.on_wOut, + kfn.alif_firingCounter, + kfn.alif_firingTargetFrequency, + kfn.alif_arrayProjection4d, + kfn.alif_error, + modelError, + outputError, + kfn.inputSize, + kfn.bk, + label, + + kfn.alif_epsilonRecA, + kfn.alif_beta, + ) + + onComputeParamsChange!(kfn.on_phi, + kfn.on_epsilonRec, + kfn.on_eta, + kfn.on_eRec, + kfn.on_wOutChange, + kfn.on_arrayProjection4d, + kfn.on_error, + outputError, + ) + # error("DEBUG -> kfn compute_paramsChange! $(Dates.now())") +end + +function lifComputeParamsChange!( timeStep::CuArray, + phi::CuArray, + epsilonRec::CuArray, + eta::CuArray, + eRec::CuArray, + wRec::CuArray, + exInType::CuArray, + wRecChange::CuArray, + wOut::CuArray, + firingCounter::CuArray, + firingTargetFrequency::CuArray, + arrayProjection4d::CuArray, + nError::CuArray, + modelError::CuArray, + outputError::CuArray, + inputSize::CuArray, + bk::CuArray, + label, + ) + eRec .= phi .* epsilonRec + + # 2D wRec matrix contain input, lif, alif neurons. I need only lif neurons + startIndex = prod(inputSize) +1 + stopIndex = startIndex + size(wRec, 3) -1 + startCol = CartesianIndices(wRec)[startIndex][2] + stopCol = CartesianIndices(wRec)[stopIndex][2] + + # some RSNN neuron that has direct connection to output neuron need to get Bjk + # from output neuron that represent correct answer, the rest of RSNN get random Bjk + onW = @view(wOut[:, startCol:stopCol, sum(label), 1]) + _bk = @view(bk[:, startCol:stopCol, 1]) + mask = iszero.(onW) + bk_ = mask .* _bk + bkComposed = onW .+ bk_ + nError = bkComposed .* modelError + nError = reshape(nError, (1,1,:,1)) + + # _,_,i3,_ = size(wOut) + # for i in 1:i3 + # # nError a.k.a. learning signal use dopamine concept, + # # this neuron receive summed error signal (modelError) + # onW = @view(wOut[:, startCol:stopCol, i, 1]) + # _bk = @view(bk[:, startCol:stopCol, i, 1]) + # mask = (iszero).(onW) + # bk_ = mask .* _bk + # bkComposed = onW .+ bk_ + + # nError = bkComposed .* modelError + # nError = reshape(nError, (1,1,:,1)) + + # # compute wRecChange of all neurons wrt to iᵗʰ output neuron + # wRecChange .+= (eta .* nError .* eRec) + # end + + # compute wRecChange of all neurons wrt to iᵗʰ output neuron + wRecChange .+= (eta .* nError .* eRec) + + # frequency regulator + wRecChange .+= 0.001 .* ((firingTargetFrequency - (firingCounter./timeStep)) ./ timeStep) .* + eta .* eRec + # if sum(timeStep) == 785 + # epsilonRec_cpu = epsilonRec |> cpu + # println("modelError $modelError $(size(modelError))", modelError) + # println("") + # println("wOutSum $(size(wOutSum))") + # wchange = (eta .* nError .* eRec) |> cpu + # println("wchange 5 1 ", wchange[:,:,5,1]) + # println("") + # println("epsilonRec 5 1 ", epsilonRec_cpu[:,:,5,1]) + # println("") + # error("DEBUG lifComputeParamsChange!") + # end + # error("DEBUG lifComputeParamsChange!") + # reset epsilonRec + epsilonRec .= 0 +end + +function alifComputeParamsChange!( timeStep::CuArray, + phi::CuArray, + epsilonRec::CuArray, + eta::CuArray, + eRec::CuArray, + wRec::CuArray, + exInType::CuArray, + wRecChange::CuArray, + wOut::CuArray, + firingCounter::CuArray, + firingTargetFrequency::CuArray, + arrayProjection4d::CuArray, + nError::CuArray, + modelError::CuArray, + outputError::CuArray, + inputSize::CuArray, + bk::CuArray, + label, + + epsilonRecA::CuArray, + beta::CuArray, + ) + + eRec .= phi .* (epsilonRec .- (beta .* epsilonRecA)) # use eq. 25 + + # 2D wRec matrix contain input, lif, alif neurons. I need only lif neurons + startIndex = prod(inputSize) +1 + stopIndex = startIndex + size(wRec, 3) -1 + startCol = CartesianIndices(wRec)[startIndex][2] + stopCol = CartesianIndices(wRec)[stopIndex][2] + + # some RSNN neuron that has direct connection to output neuron need to get Bjk + # from output neuron that represent correct answer, the rest of RSNN get random Bjk + onW = @view(wOut[:, startCol:stopCol, sum(label), 1]) + _bk = @view(bk[:, startCol:stopCol, 1]) + mask = iszero.(onW) + bk_ = mask .* _bk + bkComposed = onW .+ bk_ + nError = bkComposed .* modelError + nError = reshape(nError, (1,1,:,1)) + + wRecChange .+= (eta .* nError .* eRec) + + # frequency regulator + wRecChange .+= 0.001 .* ((firingTargetFrequency - (firingCounter./timeStep)) ./ timeStep) .* + eta .* eRec + + # reset epsilonRec + epsilonRec .= 0 + epsilonRecA .= 0 + + # error("DEBUG -> alifComputeParamsChange! $(Dates.now())") +end + +function onComputeParamsChange!(phi::CuArray, + epsilonRec::CuArray, + eta::CuArray, + eRec::CuArray, + wOutChange::CuArray, + arrayProjection4d::CuArray, + nError::CuArray, + outputError::CuArray # outputError is output neuron's error + ) + + eRec .= phi .* epsilonRec + nError .= reshape(outputError, (1, 1, :, size(outputError, 2))) .* arrayProjection4d + wOutChange .+= (-eta .* nError .* eRec) + + # reset epsilonRec + epsilonRec .= 0 + + # error("DEBUG -> onComputeParamsChange! $(Dates.now())") +end + +function lifComputeParamsChange!( phi::AbstractArray, + epsilonRec::AbstractArray, + eta::AbstractArray, + wRec::AbstractArray, + wRecChange::AbstractArray, + wOut::AbstractArray, + modelError::AbstractArray) + d1, d2, d3, d4 = size(epsilonRec) + # Bₖⱼ in paper, sum() to get each neuron's total wOut weight + wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4)) + + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + # how much error of this neuron 1-spike causing each output neuron's error + + view(wRecChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .* + # eRec + ( + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .* + # nError a.k.a. learning signal + ( + view(modelError, :, j)[1] * # dopamine concept, this neuron receive summed error signal + # RSNN neuron's total wOut weight (neuron synaptic subscription .* wOutSum) + view(wOutSum, :, :, j)[i] + ) + ) + end +end + +function alifComputeParamsChange!( phi::AbstractArray, + epsilonRec::AbstractArray, + epsilonRecA::AbstractArray, + eta::AbstractArray, + wRec::AbstractArray, + wRecChange::AbstractArray, + beta::AbstractArray, + wOut::AbstractArray, + modelError::AbstractArray) + d1, d2, d3, d4 = size(epsilonRec) + + # Bₖⱼ in paper, sum() to get each neuron's total wOut weight + wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4)) + + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + # how much error of this neuron 1-spike causing each output neuron's error + + view(wRecChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .* + # eRec + ( + # eRec_v + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .+ + # eRec_a + ((view(phi, :, :, i, j)[1] * view(beta, :, :, i, j)[1]) .* + view(epsilonRecA, :, :, i, j)) + ) .* + # nError a.k.a. learning signal + ( + view(modelError, :, j)[1] * + # RSNN neuron's total wOut weight (neuron synaptic subscription .* wOutSum) + view(wOutSum, :, :, j)[i] + # sum(GeneralUtils.isNotEqual.(view(wRec, :, :, i, j), 0) .* + # view(wOutSum, :, :, j)) + ) + end +end + +function onComputeParamsChange!(phi::AbstractArray, + epsilonRec::AbstractArray, + eta::AbstractArray, + wOutChange::AbstractArray, + outputError::AbstractArray) + d1, d2, d3, d4 = size(epsilonRec) + + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + # how much error of this neuron 1-spike causing each output neuron's error + + view(wOutChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .* + # eRec + ( + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .* + # nError a.k.a. learning signal, output neuron receives error of its own answer - correct answer. + view(outputError, :, j)[i] + ) + end +end + +function learn!(kfn::kfn_1, progress, device=cpu) + # lif learn + kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapseReconnectDelay = + lifLearn(kfn.lif_wRec, + kfn.lif_wRecChange, + kfn.lif_exInType, + kfn.lif_arrayProjection4d, + kfn.lif_neuronInactivityCounter, + kfn.lif_synapseReconnectDelay, + kfn.lif_synapseConnectionNumber, + kfn.lif_synapticActivityCounter, + kfn.lif_eta, + kfn.lif_vt, + kfn.zitCumulative, + progress, + device) + + # alif learn + kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapseReconnectDelay = + alifLearn(kfn.alif_wRec, + kfn.alif_wRecChange, + kfn.alif_exInType, + kfn.alif_arrayProjection4d, + kfn.alif_neuronInactivityCounter, + kfn.alif_synapseReconnectDelay, + kfn.alif_synapseConnectionNumber, + kfn.alif_synapticActivityCounter, + kfn.alif_eta, + kfn.alif_vt, + kfn.zitCumulative, + progress, + device) + + # on learn + onLearn!(kfn.on_wOut, + kfn.on_wOutChange, + kfn.on_arrayProjection4d) + + # wrap up learning session + if kfn.learningStage == [3] + kfn.learningStage = [0] + end + # error("DEBUG -> kfn learn! $(Dates.now())") +end + +# function lifLearn(wRec, +# exInType, +# wRecChange, +# arrayProjection4d, +# neuronInactivityCounter, +# synapseReconnectDelay, +# synapseConnectionNumber, +# synapticWChangeCounter, #TODO +# eta, +# zitCumulative, +# device) + +# # merge learning weight with average learning weight of all batch +# wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d +# wRec .= (exInType .* wRec) .+ wch + +# arrayProjection4d_cpu = arrayProjection4d |> cpu +# wRec_cpu = wRec |> cpu +# wRec_cpu = wRec_cpu[:,:,:,1] # since every batch has the same neuron wRec, (row, col, n) +# eta_cpu = eta |> cpu +# eta_cpu = eta_cpu[:,:,:,1] +# neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu +# neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n) +# synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu +# synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1] +# zitCumulative_cpu = zitCumulative |> cpu +# zitCumulative_cpu = zitCumulative_cpu[:,:,1] # (row, col) + +# # -W if less than 10% of repeat avg, +W otherwise +# _, _, i3 = size(wRec_cpu) +# for i in 1:i3 +# x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i])) +# mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1) +# wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i] +# end + +# # weak / negative synaptic connection will get randomed in neuroplasticity() +# wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0 + +# # neuroplasticity, work on CPU side +# wRec_cpu = neuroplasticity(synapseConnectionNumber, +# zitCumulative_cpu, +# wRec_cpu, +# neuronInactivityCounter_cpu, +# synapseReconnectDelay_cpu) + +# wRec_cpu = wRec_cpu .* arrayProjection4d_cpu +# wRec = wRec_cpu |> device + +# neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu +# neuronInactivityCounter = neuronInactivityCounter_cpu |> device + +# synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu +# synapseReconnectDelay = synapseReconnectDelay_cpu |> device + +# return wRec, neuronInactivityCounter, synapseReconnectDelay +# end + +function lifLearn(wRec, + wRecChange, + exInType, + arrayProjection4d, + neuronInactivityCounter, + synapseReconnectDelay, + synapseConnectionNumber, + synapticActivityCounter, + eta, + vt, + zitCumulative, + progress, + device) + + + # transfer data to cpu + arrayProjection4d_cpu = arrayProjection4d |> cpu + wRec_cpu = wRec |> cpu + wRecChange_cpu = wRecChange |> cpu + wRecChange_cpu = wRecChange_cpu[:,:,:,1] + eta_cpu = eta |> cpu + eta_cpu = eta_cpu[:,:,:,1] + exInType_cpu = exInType |> cpu + exInType_cpu = exInType_cpu[:,:,:,1] + neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu + neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n) + synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu + synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1] + synapticActivityCounter_cpu = synapticActivityCounter |> cpu + synapticActivityCounter_cpu = synapticActivityCounter_cpu[:,:,:,1] + zitCumulative_cpu = zitCumulative |> cpu + zitCumulative_cpu = zitCumulative_cpu[:,:,1] + + # neuroplasticity, work on CPU side + wRec_cpu, neuronInactivityCounter_cpu, synapseReconnectDelay_cpu = + neuroplasticity(synapseConnectionNumber, + zitCumulative_cpu, + wRec_cpu, + exInType_cpu, + wRecChange_cpu, + vt, + eta, + neuronInactivityCounter_cpu, + synapseReconnectDelay_cpu, + synapticActivityCounter_cpu, + progress,) + + + + + + + + + # # merge learning weight with average learning weight of all batch + # wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d + # wRec .= (exInType .* wRec) .+ wch + + # # (row, col) + + # # -W if less than 10% of repeat avg, +W otherwise + # _, _, i3 = size(wRec_cpu) + # for i in 1:i3 + # x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i])) + # mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1) + # wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i] + # end + + # # weak / negative synaptic connection will get randomed in neuroplasticity() + # wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0 + + # neuroplasticity, work on CPU side + # wRec_cpu = neuroplasticity(synapseConnectionNumber, + # zitCumulative_cpu, + # wRec_cpu, + # wRecChange_cpu, + # vt, + # neuronInactivityCounter_cpu, + # synapseReconnectDelay_cpu) + + # transfer data backto gpu + wRec_cpu = wRec_cpu .* arrayProjection4d_cpu + wRec = wRec_cpu |> device + neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu + neuronInactivityCounter = neuronInactivityCounter_cpu |> device + synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu + synapseReconnectDelay = synapseReconnectDelay_cpu |> device + + return wRec, neuronInactivityCounter, synapseReconnectDelay +end + +function alifLearn(wRec, + wRecChange, + exInType, + arrayProjection4d, + neuronInactivityCounter, + synapseReconnectDelay, + synapseConnectionNumber, + synapticActivityCounter, + eta, + vt, + zitCumulative, + progress, + device) + + # merge learning weight with average learning weight of all batch + wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d + wRec .= (exInType .* wRec) .+ wch + + arrayProjection4d_cpu = arrayProjection4d |> cpu + wRec_cpu = wRec |> cpu + wRec_cpu = wRec_cpu[:,:,:,1] # since every batch has the same neuron wRec, (row, col, n) + eta_cpu = eta |> cpu + eta_cpu = eta_cpu[:,:,:,1] + neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu + neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n) + synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu + synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1] + zitCumulative_cpu = zitCumulative |> cpu + zitCumulative_cpu = zitCumulative_cpu[:,:,1] # (row, col) + + # -W if less than 10% of repeat avg, +W otherwise + _, _, i3 = size(wRec_cpu) + for i in 1:i3 + x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i])) + mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1) + wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i] + end + + # weak / negative synaptic connection will get randomed in neuroplasticity() + wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0 + + # neuroplasticity, work on CPU side + wRec_cpu = neuroplasticity(synapseConnectionNumber, + zitCumulative_cpu, + wRec_cpu, + neuronInactivityCounter_cpu, + synapseReconnectDelay_cpu) + + wRec_cpu = wRec_cpu .* arrayProjection4d_cpu + wRec = wRec_cpu |> device + + neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu + neuronInactivityCounter = neuronInactivityCounter_cpu |> device + + synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu + synapseReconnectDelay = synapseReconnectDelay_cpu |> device + + # error("DEBUG -> alifLearn! $(Dates.now())") + return wRec, neuronInactivityCounter, synapseReconnectDelay +end + +function onLearn!(wOut, + wOutChange, + arrayProjection4d) + # merge learning weight with average learning weight + wOut .+= (sum(wOutChange, dims=4) ./ (size(wOut, 4))) .* arrayProjection4d + + # adaptive wOut to help convergence using c_decay + wOut .-= 0.001 .* wOut +end + +alltrue(args...) = false ∈ [args...] ? false : true +isbetween(x, lowerlimit, upperlimit) = lowerlimit < x < upperlimit ? true : false + +#WORKING 1) implement 90% +w, 10% -w 2) rewrite this function +function neuroplasticity(synapseConnectionNumber, + zitCumulative, # (row, col) + wRec, # (row, col, n) + exInType, + wRecChange, + vt, + eta, + neuronInactivityCounter, + synapseReconnectDelay, + synapticActivityCounter, + progress,) # (row, col, n) + i1,i2,i3 = size(wRec) + println("eta $(size(eta))") + println("wRec $(size(wRec))") + error("DEBUG -> neuroplasticity $(Dates.now())") + + if progress == 2 # no need to learn + # skip neuroplasticity + #TODO I may need to do something with neuronInactivityCounter and other variables + wRecChange .= 0 + elseif progress == 1 # progress increase + # ready to reconnect synapse must not have wRecChange + mask = (!isequal).(wRec, 0) + wRecChange .*= mask + + # merge learning weight with average learning weight of all batch + wRec .= abs.((exInType .* wRec) .+ wRecChange) # abs because wRec doesn't carry sign + + # seperate active synapse out of inactive in this signal + mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) + mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) + + # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec + avgActivity = sum(synapticActivityCounter) / length(synapticActivityCounter) + lowerlimit = 0.1 * avgActivity + + # +w, synapse with more than 10% of avg activity get increase weight by eta + mask_more = (!isless).(synapticActivityCounter, lowerlimit) + mask_2 = alltrue.(mask_activeSynapse, mask_more) + mask_2 .*= 1 .+ eta # minor activity synapse weight will be reduced by eta + wRec .*= mask_2 + + # -w, synapse with less than 10% of avg activity get reduced weight by eta + mask_less = isless.(synapticActivityCounter, lowerlimit) # 1st criteria + + mask_3 = alltrue.(mask_activeSynapse, mask_less) + mask_3 .*= 1 .- eta # minor activity synapse weight will be reduced by eta + wRec .*= mask_3 + + # -w all non-fire connection except mature connection + mask_notmature = isless.(wRec, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 + mask_1 = alltrue.(mask_inactiveSynapse, mask_notmature) + mask_1 .*= 1 .- eta + wRec .*= mask_1 + + #WORKING prune weak connection + # mark weak / negative synaptic connection so they will get randomed in neuroplasticity() + mask_weak = isbetween.(wRec, 0.0, 0.01) + mask_notweak = (!isbetween).(wRec, 0.0, 0.01) + wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 + r = rand((1:1000), size(wRec)) # synapse random wait time to reconnect + r .*= mask_weak + synapticActivityCounter .*= mask_notweak # all marked weak synapse is set 0 + synapticActivityCounter .+= r # set pruned synapse to random wait time + + #TODO rewire synapse connection + + + elseif progress == 0 # no progress, no weight update, only rewire + + # -w all non-fire connection except mature connection + + # prune weak connection + + # rewire synapse connection + elseif progress == -1 # setback + # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec + + # -w all non-fire connection except mature connection + + # prune weak connection + + # rewire synapse connection + else + error("undefined condition line $(@__LINE__)") + end + + # error("DEBUG -> neuroplasticity $(Dates.now())") + + # merge learning weight with average learning weight of all batch + wRec .= abs.((exInType .* wRec) .+ wRecChange) # abs because wRec doesn't carry sign + + + # adjust weight based on vt progress and repeatition (90% +w, 10% -w) depend on epsilonRec + + + + # -w all non-fire connection except mature connection + + + + # prune weak connection + + + + + # rewire synapse connection + + + + # for each neuron, find total number of synaptic conn that should draw + # new connection to firing and non-firing neurons pool + subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber)) + + # for each neuron, count how many synapse already subscribed to firing-neurons + zw = zitCumulative .* wRec + subToFireNeuron_current = sum(GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) + zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 + projection = ones(i1,i2,i3) + zitMask = zitMask .* projection # (row, col, n) + totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n) + println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced") + + # clear -1.0 marker + GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99) + GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required + + for i in 1:i3 + if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight + println("neuron die") + neuronInactivityCounter[:,:,i] .= 0 # reset + w = random_wRec(i1,i2,1,synapseConnectionNumber) + wRec[:,:,i] .= w + + a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron + mask = (!iszero).(w) + GeneralUtils.replaceElements!(mask, 1, a, 0) + synapseReconnectDelay[:,:,i] = a + else + remaining = 0 + if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe + toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] + totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn + # add new conn to firing neurons pool + remaining = addNewSynapticConn!(zitMask[:,:,i], 1, + @view(wRec[:,:,i]), + @view(synapseReconnectDelay[:,:,i]), + toAddConn) + totalNewConn[1,1,i] += remaining + end + + # add new conn to non-firing neurons pool + remaining = addNewSynapticConn!(zitMask[:,:,i], 0, + @view(wRec[:,:,i]), + @view(synapseReconnectDelay[:,:,i]), + totalNewConn[1,1,i]) + if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot + remaining = addNewSynapticConn!(zitMask[:,:,i], 1, + @view(wRec[:,:,i]), + @view(synapseReconnectDelay[:,:,i]), + remaining) + end + end + end + + # error("DEBUG -> neuroplasticity $(Dates.now())") + return wRec +end + +# function neuroplasticity(synapseConnectionNumber, +# zitCumulative, # (row, col) +# wRec, # (row, col, n) +# neuronInactivityCounter, +# synapseReconnectDelay) # (row, col, n) + +# i1,i2,i3 = size(wRec) + +# # for each neuron, find total number of synaptic conn that should draw +# # new connection to firing and non-firing neurons pool +# subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber)) + +# # for each neuron, count how many synap already subscribed to firing-neurons +# zw = zitCumulative .* wRec +# subToFireNeuron_current = sum(GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n) +# zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0 +# projection = ones(i1,i2,i3) +# zitMask = zitMask .* projection # (row, col, n) +# totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n) +# println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced") + +# # clear -1.0 marker +# GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99) +# GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required + +# for i in 1:i3 +# if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight +# println("neuron die") +# neuronInactivityCounter[:,:,i] .= 0 # reset +# w = random_wRec(i1,i2,1,synapseConnectionNumber) +# wRec[:,:,i] .= w + +# a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron +# mask = (!iszero).(w) +# GeneralUtils.replaceElements!(mask, 1, a, 0) +# synapseReconnectDelay[:,:,i] = a +# else +# remaining = 0 +# if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe +# toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] +# totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn +# # add new conn to firing neurons pool +# remaining = addNewSynapticConn!(zitMask[:,:,i], 1, +# @view(wRec[:,:,i]), +# @view(synapseReconnectDelay[:,:,i]), +# toAddConn) +# totalNewConn[1,1,i] += remaining +# end + +# # add new conn to non-firing neurons pool +# remaining = addNewSynapticConn!(zitMask[:,:,i], 0, +# @view(wRec[:,:,i]), +# @view(synapseReconnectDelay[:,:,i]), +# totalNewConn[1,1,i]) +# if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot +# remaining = addNewSynapticConn!(zitMask[:,:,i], 1, +# @view(wRec[:,:,i]), +# @view(synapseReconnectDelay[:,:,i]), +# remaining) +# end +# end +# end + +# # error("DEBUG -> neuroplasticity $(Dates.now())") +# return wRec +# end + +# learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5 + +function learningLiquidity(x) + if x > 10000 + y = 0.0 + elseif x < -10000 + y = 1.0 + else + y = -5e-05x+0.5 # range -10000 to +10000 + end + return y +end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +end # module \ No newline at end of file diff --git a/src/learn.jl b/src/learn.jl index c032035..d637100 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -591,6 +591,10 @@ function neuroplasticity(synapseConnectionNumber, #TODO I may need to do something with neuronInactivityCounter and other variables wRecChange .= 0 elseif progress == 1 # progress increase + # ready to reconnect synapse must not have wRecChange + mask = (!isequal).(wRec, 0) + wRecChange .*= mask + # merge learning weight with average learning weight of all batch wRec .= abs.((exInType .* wRec) .+ wRecChange) # abs because wRec doesn't carry sign @@ -609,7 +613,7 @@ function neuroplasticity(synapseConnectionNumber, wRec .*= mask_2 # -w, synapse with less than 10% of avg activity get reduced weight by eta - mask_less = isless.(synapticActivityCounter, lowerlimit) # 1st criteria + mask_less = isbetween.(synapticActivityCounter, 0.0, lowerlimit) # 1st criteria mask_3 = alltrue.(mask_activeSynapse, mask_less) mask_3 .*= 1 .- eta # minor activity synapse weight will be reduced by eta @@ -621,12 +625,56 @@ function neuroplasticity(synapseConnectionNumber, mask_1 .*= 1 .- eta wRec .*= mask_1 - # prune weak connection - # mark weak / negative synaptic connection so they will get randomed in neuroplasticity() - mask = isbetween.(wRec, 0.0, 0.01) - wRec = GeneralUtils.replaceBetween.(wRec, 0.0, 0.01, -1.0) # mark with -1.0 + # prune synapse + mask_weak = isbetween.(wRec, 0.0, 0.01) + mask_notweak = (!isbetween).(wRec, 0.0, 0.01) + wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned + r = rand((1:1000), size(wRec)) .* mask_weak # synapse random wait time to reconnect + synapticActivityCounter .*= mask_notweak # all marked weak synapse activity are reset + synapticActivityCounter .+= (mask_weak .* -1.0) + synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ r # set pruned synapse to random wait time #WORKING rewire synapse connection + synapseReconnectDelay mark timeStep while also counting delay == BUG + + for i in 1:i3 # neuron-by-neuron + if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight + println("neuron die") + neuronInactivityCounter[:,:,i] .= 0 # reset + w = random_wRec(i1,i2,1,synapseConnectionNumber) + wRec[:,:,i] .= w + + #WORKING + a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron + mask = (!iszero).(w) + GeneralUtils.replaceElements!(mask, 1, a, 0) + synapseReconnectDelay[:,:,i] = a + else + remaining = 0 + if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe + toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i] + totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn + # add new conn to firing neurons pool + remaining = addNewSynapticConn!(zitMask[:,:,i], 1, + @view(wRec[:,:,i]), + @view(synapseReconnectDelay[:,:,i]), + toAddConn) + totalNewConn[1,1,i] += remaining + end + + # add new conn to non-firing neurons pool + remaining = addNewSynapticConn!(zitMask[:,:,i], 0, + @view(wRec[:,:,i]), + @view(synapseReconnectDelay[:,:,i]), + totalNewConn[1,1,i]) + if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot + remaining = addNewSynapticConn!(zitMask[:,:,i], 1, + @view(wRec[:,:,i]), + @view(synapseReconnectDelay[:,:,i]), + remaining) + end + end + end elseif progress == 0 # no progress, no weight update, only rewire diff --git a/src/type.jl b/src/type.jl index 5ba355f..911c516 100644 --- a/src/type.jl +++ b/src/type.jl @@ -232,17 +232,9 @@ function kfn_1(params::Dict; device=cpu) # count subscribed synapse activity, just like epsilonRec but without decay. # use to adjust weight based on how often neural pathway is used - kfn.lif_synapseReconnectDelay = Array(similar(kfn.lif_wRec) .= -0.99) # -0.99 for non-sub conn - mask = Array((!iszero).(kfn.lif_wRec)) - # initial value subscribed conn - for i in eachindex(mask) - if mask[i] == 1 - kfn.lif_synapseReconnectDelay[i] = rand(1:1000) - end - end - kfn.lif_synapseReconnectDelay = kfn.lif_synapseReconnectDelay |> device + kfn.lif_synapseReconnectDelay = (similar(kfn.lif_wRec) .= -1.0) # -1.0 for non-sub conn - kfn.lif_synapticActivityCounter = Array(similar(kfn.lif_wRec) .= -0.99) # -0.99 for non-sub conn + kfn.lif_synapticActivityCounter = Array(similar(kfn.lif_wRec) .= -1.0) # -1.0 for non-sub conn mask = Array((!iszero).(kfn.lif_wRec)) # initial value subscribed conn GeneralUtils.replaceElements!(mask, 1, kfn.lif_synapticActivityCounter, 0.0) @@ -291,17 +283,9 @@ function kfn_1(params::Dict; device=cpu) kfn.alif_firingCounter = (similar(kfn.alif_wRec) .= 0) kfn.alif_firingTargetFrequency = (similar(kfn.alif_wRec) .= 0.1) kfn.alif_neuronInactivityCounter = (similar(kfn.alif_wRec) .= 0) - kfn.alif_synapseReconnectDelay = Array(similar(kfn.alif_wRec) .= -0.99) # -9 for non-sub conn - mask = Array((!iszero).(kfn.alif_wRec)) - # initial value subscribed conn - for i in eachindex(mask) - if mask[i] == 1 - kfn.alif_synapseReconnectDelay[i] = rand(1:1000) - end - end - kfn.alif_synapseReconnectDelay = kfn.alif_synapseReconnectDelay |> device + kfn.alif_synapseReconnectDelay = (similar(kfn.alif_wRec) .= -1.0) # -1.0 for non-sub conn - kfn.alif_synapticActivityCounter = Array(similar(kfn.alif_wRec) .= -0.99) # -0.99 for non-sub conn + kfn.alif_synapticActivityCounter = Array(similar(kfn.alif_wRec) .= -1.0) # -1.0 for non-sub conn mask = Array((!iszero).(kfn.alif_wRec)) # initial value subscribed conn GeneralUtils.replaceElements!(mask, 1, kfn.alif_synapticActivityCounter, 0.0) @@ -398,7 +382,7 @@ function random_wRec(row, col, n, synapseConnectionNumber) for slice in eachslice(w, dims=3) pool = shuffle!([1:row*col...])[1:synapseConnectionNumber] for i in pool - slice[i] = rand(0.01:0.01:0.1) # assign weight to synaptic connection. /10 to start small, + slice[i] = rand(0.01:0.01:0.5) # assign weight to synaptic connection. /10 to start small, # otherwise RSNN's vt Usually stay negative (-) end end