diff --git a/src/forward.jl b/src/forward.jl index 3e97098..d8869ea 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -17,6 +17,8 @@ function (kfn::kfn_1)(input::AbstractArray) # what to do at the start of learning round if view(kfn.learningStage, 1)[1] == 1 + kfn.timeStep .= 1 + # reset learning params kfn.zitCumulative = kfn.zitCumulative[:,:,1,:] diff --git a/src/learn.jl b/src/learn.jl index 9a8ddf5..05c0eab 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -101,23 +101,6 @@ function lifComputeParamsChange!( timeStep::CuArray, bkComposed = onW .+ bk_ nError = bkComposed .* modelError nError = reshape(nError, (1,1,:,1)) - - # _,_,i3,_ = size(wOut) - # for i in 1:i3 - # # nError a.k.a. learning signal use dopamine concept, - # # this neuron receive summed error signal (modelError) - # onW = @view(wOut[:, startCol:stopCol, i, 1]) - # _bk = @view(bk[:, startCol:stopCol, i, 1]) - # mask = (iszero).(onW) - # bk_ = mask .* _bk - # bkComposed = onW .+ bk_ - - # nError = bkComposed .* modelError - # nError = reshape(nError, (1,1,:,1)) - - # # compute wRecChange of all neurons wrt to iᵗʰ output neuron - # wRecChange .+= (eta .* nError .* eRec) - # end # compute wRecChange of all neurons wrt to iᵗʰ output neuron wRecChange .+= (eta .* nError .* eRec) @@ -125,19 +108,6 @@ function lifComputeParamsChange!( timeStep::CuArray, # frequency regulator wRecChange .+= 0.001 .* ((firingTargetFrequency - (firingCounter./timeStep)) ./ timeStep) .* eta .* eRec - # if sum(timeStep) == 785 - # epsilonRec_cpu = epsilonRec |> cpu - # println("modelError $modelError $(size(modelError))", modelError) - # println("") - # println("wOutSum $(size(wOutSum))") - # wchange = (eta .* nError .* eRec) |> cpu - # println("wchange 5 1 ", wchange[:,:,5,1]) - # println("") - # println("epsilonRec 5 1 ", epsilonRec_cpu[:,:,5,1]) - # println("") - # error("DEBUG lifComputeParamsChange!") - # end - # error("DEBUG lifComputeParamsChange!") # reset epsilonRec epsilonRec .= 0 end @@ -303,7 +273,7 @@ end function learn!(kfn::kfn_1, progress, device=cpu) # lif learn - kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapseReconnectDelay = + kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapticActivityCounter, kfn.lif_synapseReconnectDelay = lifLearn(kfn.lif_wRec, kfn.lif_wRecChange, kfn.lif_exInType, @@ -319,7 +289,7 @@ function learn!(kfn::kfn_1, progress, device=cpu) device) # alif learn - kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapseReconnectDelay = + kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapticActivityCounter, kfn.alif_synapseReconnectDelay = alifLearn(kfn.alif_wRec, kfn.alif_wRecChange, kfn.alif_exInType, @@ -346,64 +316,6 @@ function learn!(kfn::kfn_1, progress, device=cpu) # error("DEBUG -> kfn learn! $(Dates.now())") end -# function lifLearn(wRec, -# exInType, -# wRecChange, -# arrayProjection4d, -# neuronInactivityCounter, -# synapseReconnectDelay, -# synapseConnectionNumber, -# synapticWChangeCounter, #TODO -# eta, -# zitCumulative, -# device) - -# # merge learning weight with average learning weight of all batch -# wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d -# wRec .= (exInType .* wRec) .+ wch - -# arrayProjection4d_cpu = arrayProjection4d |> cpu -# wRec_cpu = wRec |> cpu -# wRec_cpu = wRec_cpu[:,:,:,1] # since every batch has the same neuron wRec, (row, col, n) -# eta_cpu = eta |> cpu -# eta_cpu = eta_cpu[:,:,:,1] -# neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu -# neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n) -# synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu -# synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1] -# zitCumulative_cpu = zitCumulative |> cpu -# zitCumulative_cpu = zitCumulative_cpu[:,:,1] # (row, col) - -# # -W if less than 10% of repeat avg, +W otherwise -# _, _, i3 = size(wRec_cpu) -# for i in 1:i3 -# x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i])) -# mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1) -# wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i] -# end - -# # weak / negative synaptic connection will get randomed in neuroplasticity() -# wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0 - -# # neuroplasticity, work on CPU side -# wRec_cpu = neuroplasticity(synapseConnectionNumber, -# zitCumulative_cpu, -# wRec_cpu, -# neuronInactivityCounter_cpu, -# synapseReconnectDelay_cpu) - -# wRec_cpu = wRec_cpu .* arrayProjection4d_cpu -# wRec = wRec_cpu |> device - -# neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu -# neuronInactivityCounter = neuronInactivityCounter_cpu |> device - -# synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu -# synapseReconnectDelay = synapseReconnectDelay_cpu |> device - -# return wRec, neuronInactivityCounter, synapseReconnectDelay -# end - function lifLearn(wRec, wRecChange, exInType, @@ -418,7 +330,6 @@ function lifLearn(wRec, progress, device) - # transfer data to cpu arrayProjection4d_cpu = arrayProjection4d |> cpu wRec_cpu = wRec |> cpu @@ -430,7 +341,7 @@ function lifLearn(wRec, synapticActivityCounter_cpu = synapticActivityCounter |> cpu zitCumulative_cpu = zitCumulative |> cpu # neuroplasticity, work on CPU side - wRec_cpu, neuronInactivityCounter_cpu, synapseReconnectDelay_cpu = + wRec_cpu, neuronInactivityCounter_cpu, synapticActivityCounter_cpu, synapseReconnectDelay_cpu = neuroplasticity(synapseConnectionNumber, zitCumulative_cpu, wRec_cpu, @@ -444,14 +355,12 @@ function lifLearn(wRec, progress,) # transfer data backto gpu - wRec_cpu = wRec_cpu .* arrayProjection4d_cpu wRec = wRec_cpu |> device - neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu neuronInactivityCounter = neuronInactivityCounter_cpu |> device - synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu + synapticActivityCounter = synapticActivityCounter_cpu |> device synapseReconnectDelay = synapseReconnectDelay_cpu |> device - error("DEBUG -> lifLearn! $(Dates.now())") - return wRec, neuronInactivityCounter, synapseReconnectDelay + # error("DEBUG -> lifLearn! $(Dates.now())") + return wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay end function alifLearn(wRec, @@ -468,46 +377,37 @@ function alifLearn(wRec, progress, device) - # merge learning weight with average learning weight of all batch - wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d - wRec .= (exInType .* wRec) .+ wch - + # transfer data to cpu arrayProjection4d_cpu = arrayProjection4d |> cpu wRec_cpu = wRec |> cpu + wRecChange_cpu = wRecChange |> cpu eta_cpu = eta |> cpu + exInType_cpu = exInType |> cpu neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu + synapticActivityCounter_cpu = synapticActivityCounter |> cpu zitCumulative_cpu = zitCumulative |> cpu - - # -W if less than 10% of repeat avg, +W otherwise - _, _, i3 = size(wRec_cpu) - for i in 1:i3 - x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i])) - mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1) - wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i] - end - - # weak / negative synaptic connection will get randomed in neuroplasticity() - wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0 - # neuroplasticity, work on CPU side - wRec_cpu = neuroplasticity(synapseConnectionNumber, - zitCumulative_cpu, - wRec_cpu, - neuronInactivityCounter_cpu, - synapseReconnectDelay_cpu) + wRec_cpu, neuronInactivityCounter_cpu, synapticActivityCounter_cpu, synapseReconnectDelay_cpu = + neuroplasticity(synapseConnectionNumber, + zitCumulative_cpu, + wRec_cpu, + exInType_cpu, + wRecChange_cpu, + vt, + eta_cpu, + neuronInactivityCounter_cpu, + synapseReconnectDelay_cpu, + synapticActivityCounter_cpu, + progress,) - wRec_cpu = wRec_cpu .* arrayProjection4d_cpu + # transfer data backto gpu wRec = wRec_cpu |> device - - neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu neuronInactivityCounter = neuronInactivityCounter_cpu |> device - - synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu + synapticActivityCounter = synapticActivityCounter_cpu |> device synapseReconnectDelay = synapseReconnectDelay_cpu |> device - - # error("DEBUG -> alifLearn! $(Dates.now())") - return wRec, neuronInactivityCounter, synapseReconnectDelay + # error("DEBUG -> alifLearn! $(Dates.now())") + return wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay end function onLearn!(wOut, @@ -522,7 +422,7 @@ end GeneralUtils.allTrue(args...) = false ∈ [args...] ? false : true -#WORKING 2) rewrite this function +#WORKING function neuroplasticity(synapseConnectionNumber, zitCumulative, # (row, col) wRec, # (row, col, n) @@ -534,12 +434,6 @@ function neuroplasticity(synapseConnectionNumber, synapseReconnectDelay, synapticActivityCounter, progress,) # (row, col, n) - i1,i2,i3 = size(wRec) - println("eta $(size(eta))") - println("wRec 1 $(size(wRec)) ", wRec[:,:,1,1]) - println("zitCumulative $(size(zitCumulative))") - - println("progress $progress") if progress == 2 # no need to learn # skip neuroplasticity @@ -584,7 +478,8 @@ function neuroplasticity(synapseConnectionNumber, end # error("DEBUG -> neuroplasticity $(Dates.now())") - return wRec + return wRec, neuronInactivityCounter, + synapticActivityCounter, synapseReconnectDelay end # learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5 diff --git a/src/snnUtil.jl b/src/snnUtil.jl index b99503f..97939b1 100644 --- a/src/snnUtil.jl +++ b/src/snnUtil.jl @@ -120,21 +120,16 @@ function growRepeatedPath!(wRec, synapticActivityCounter, eta) end function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed) - # println("wRec ", wRec[:,:,1,1]) mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) mask_notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 - # println("mask_notmature ", mask_notmature[:,:,1,1]) mask_1 = GeneralUtils.allTrue.(mask_inactiveSynapse, mask_notmature) - # println("mask_1 ", mask_1[:,:,1,1]) mask_2 = mask_1 .* (1 .- eta) GeneralUtils.replaceElements!(mask_2, 0, 1) # replace 0 with 1 so mask * Wrec will not get 0 weight wRec .*= mask_2 - # println("wRec 2 ", wRec[:,:,1,1]) end function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) - println("weak synapse ", sum(mask_weak)) mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned # all weak synapse activity are reset @@ -165,22 +160,23 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr else for ind in eachindex(synapseReconnectDelay[:,:,n,i4]) timemark = synapseReconnectDelay[:,:,n,i4][ind] - # println("timemark ", timemark) + if timemark > 0 #TODO not fully tested. mark timeStep available + timemark = Int(timemark) # get neuron pool at 10 timeStep earlier earlier = timemark - 10 > 0 ? timemark - 10 : timemark - pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) + pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) #BUG BoundsError: attempt to access 10×25×801 Array{Float32, 3} at index [1:10, 1:25, 1340.0f0:1.0f0:1350.0f0] if sum(pool) != 0 indices = findall(x -> x != 0, pool) - pick = rand(indices) + pick = rand(indices) # cartesian indice wRec[pick] = rand(0.01:0.01:0.05) synapticActivityCounter[pick] = 0 synapseReconnectDelay[pick] = -0.1 - error("DEBUG -> rewireSynapse!") + # error("DEBUG -> rewireSynapse!") else # if neurons not firing at all, try again next time - synapticActivityCounter[pick] = 0 - synapseReconnectDelay[:,:,n,i4] = rand(1:1000) - error("DEBUG -> rewireSynapse!") + synapticActivityCounter[:,:,n,i4][ind] = 0 + synapseReconnectDelay[:,:,n,i4][ind] = rand(1:1000) * -1 + # error("DEBUG -> rewireSynapse!") end end end