This commit is contained in:
ton
2023-09-17 16:49:50 +07:00
parent 213f35b1ff
commit 69b4862d48
3 changed files with 39 additions and 146 deletions

View File

@@ -17,6 +17,8 @@ function (kfn::kfn_1)(input::AbstractArray)
# what to do at the start of learning round # what to do at the start of learning round
if view(kfn.learningStage, 1)[1] == 1 if view(kfn.learningStage, 1)[1] == 1
kfn.timeStep .= 1
# reset learning params # reset learning params
kfn.zitCumulative = kfn.zitCumulative[:,:,1,:] kfn.zitCumulative = kfn.zitCumulative[:,:,1,:]

View File

@@ -102,42 +102,12 @@ function lifComputeParamsChange!( timeStep::CuArray,
nError = bkComposed .* modelError nError = bkComposed .* modelError
nError = reshape(nError, (1,1,:,1)) nError = reshape(nError, (1,1,:,1))
# _,_,i3,_ = size(wOut)
# for i in 1:i3
# # nError a.k.a. learning signal use dopamine concept,
# # this neuron receive summed error signal (modelError)
# onW = @view(wOut[:, startCol:stopCol, i, 1])
# _bk = @view(bk[:, startCol:stopCol, i, 1])
# mask = (iszero).(onW)
# bk_ = mask .* _bk
# bkComposed = onW .+ bk_
# nError = bkComposed .* modelError
# nError = reshape(nError, (1,1,:,1))
# # compute wRecChange of all neurons wrt to iᵗʰ output neuron
# wRecChange .+= (eta .* nError .* eRec)
# end
# compute wRecChange of all neurons wrt to iᵗʰ output neuron # compute wRecChange of all neurons wrt to iᵗʰ output neuron
wRecChange .+= (eta .* nError .* eRec) wRecChange .+= (eta .* nError .* eRec)
# frequency regulator # frequency regulator
wRecChange .+= 0.001 .* ((firingTargetFrequency - (firingCounter./timeStep)) ./ timeStep) .* wRecChange .+= 0.001 .* ((firingTargetFrequency - (firingCounter./timeStep)) ./ timeStep) .*
eta .* eRec eta .* eRec
# if sum(timeStep) == 785
# epsilonRec_cpu = epsilonRec |> cpu
# println("modelError $modelError $(size(modelError))", modelError)
# println("")
# println("wOutSum $(size(wOutSum))")
# wchange = (eta .* nError .* eRec) |> cpu
# println("wchange 5 1 ", wchange[:,:,5,1])
# println("")
# println("epsilonRec 5 1 ", epsilonRec_cpu[:,:,5,1])
# println("")
# error("DEBUG lifComputeParamsChange!")
# end
# error("DEBUG lifComputeParamsChange!")
# reset epsilonRec # reset epsilonRec
epsilonRec .= 0 epsilonRec .= 0
end end
@@ -303,7 +273,7 @@ end
function learn!(kfn::kfn_1, progress, device=cpu) function learn!(kfn::kfn_1, progress, device=cpu)
# lif learn # lif learn
kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapseReconnectDelay = kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapticActivityCounter, kfn.lif_synapseReconnectDelay =
lifLearn(kfn.lif_wRec, lifLearn(kfn.lif_wRec,
kfn.lif_wRecChange, kfn.lif_wRecChange,
kfn.lif_exInType, kfn.lif_exInType,
@@ -319,7 +289,7 @@ function learn!(kfn::kfn_1, progress, device=cpu)
device) device)
# alif learn # alif learn
kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapseReconnectDelay = kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapticActivityCounter, kfn.alif_synapseReconnectDelay =
alifLearn(kfn.alif_wRec, alifLearn(kfn.alif_wRec,
kfn.alif_wRecChange, kfn.alif_wRecChange,
kfn.alif_exInType, kfn.alif_exInType,
@@ -346,64 +316,6 @@ function learn!(kfn::kfn_1, progress, device=cpu)
# error("DEBUG -> kfn learn! $(Dates.now())") # error("DEBUG -> kfn learn! $(Dates.now())")
end end
# function lifLearn(wRec,
# exInType,
# wRecChange,
# arrayProjection4d,
# neuronInactivityCounter,
# synapseReconnectDelay,
# synapseConnectionNumber,
# synapticWChangeCounter, #TODO
# eta,
# zitCumulative,
# device)
# # merge learning weight with average learning weight of all batch
# wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d
# wRec .= (exInType .* wRec) .+ wch
# arrayProjection4d_cpu = arrayProjection4d |> cpu
# wRec_cpu = wRec |> cpu
# wRec_cpu = wRec_cpu[:,:,:,1] # since every batch has the same neuron wRec, (row, col, n)
# eta_cpu = eta |> cpu
# eta_cpu = eta_cpu[:,:,:,1]
# neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu
# neuronInactivityCounter_cpu = neuronInactivityCounter_cpu[:,:,:,1] # (row, col, n)
# synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu
# synapseReconnectDelay_cpu = synapseReconnectDelay_cpu[:,:,:,1]
# zitCumulative_cpu = zitCumulative |> cpu
# zitCumulative_cpu = zitCumulative_cpu[:,:,1] # (row, col)
# # -W if less than 10% of repeat avg, +W otherwise
# _, _, i3 = size(wRec_cpu)
# for i in 1:i3
# x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i]))
# mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1)
# wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i]
# end
# # weak / negative synaptic connection will get randomed in neuroplasticity()
# wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0
# # neuroplasticity, work on CPU side
# wRec_cpu = neuroplasticity(synapseConnectionNumber,
# zitCumulative_cpu,
# wRec_cpu,
# neuronInactivityCounter_cpu,
# synapseReconnectDelay_cpu)
# wRec_cpu = wRec_cpu .* arrayProjection4d_cpu
# wRec = wRec_cpu |> device
# neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu
# neuronInactivityCounter = neuronInactivityCounter_cpu |> device
# synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu
# synapseReconnectDelay = synapseReconnectDelay_cpu |> device
# return wRec, neuronInactivityCounter, synapseReconnectDelay
# end
function lifLearn(wRec, function lifLearn(wRec,
wRecChange, wRecChange,
exInType, exInType,
@@ -418,7 +330,6 @@ function lifLearn(wRec,
progress, progress,
device) device)
# transfer data to cpu # transfer data to cpu
arrayProjection4d_cpu = arrayProjection4d |> cpu arrayProjection4d_cpu = arrayProjection4d |> cpu
wRec_cpu = wRec |> cpu wRec_cpu = wRec |> cpu
@@ -430,7 +341,7 @@ function lifLearn(wRec,
synapticActivityCounter_cpu = synapticActivityCounter |> cpu synapticActivityCounter_cpu = synapticActivityCounter |> cpu
zitCumulative_cpu = zitCumulative |> cpu zitCumulative_cpu = zitCumulative |> cpu
# neuroplasticity, work on CPU side # neuroplasticity, work on CPU side
wRec_cpu, neuronInactivityCounter_cpu, synapseReconnectDelay_cpu = wRec_cpu, neuronInactivityCounter_cpu, synapticActivityCounter_cpu, synapseReconnectDelay_cpu =
neuroplasticity(synapseConnectionNumber, neuroplasticity(synapseConnectionNumber,
zitCumulative_cpu, zitCumulative_cpu,
wRec_cpu, wRec_cpu,
@@ -444,14 +355,12 @@ function lifLearn(wRec,
progress,) progress,)
# transfer data backto gpu # transfer data backto gpu
wRec_cpu = wRec_cpu .* arrayProjection4d_cpu
wRec = wRec_cpu |> device wRec = wRec_cpu |> device
neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu
neuronInactivityCounter = neuronInactivityCounter_cpu |> device neuronInactivityCounter = neuronInactivityCounter_cpu |> device
synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu synapticActivityCounter = synapticActivityCounter_cpu |> device
synapseReconnectDelay = synapseReconnectDelay_cpu |> device synapseReconnectDelay = synapseReconnectDelay_cpu |> device
error("DEBUG -> lifLearn! $(Dates.now())") # error("DEBUG -> lifLearn! $(Dates.now())")
return wRec, neuronInactivityCounter, synapseReconnectDelay return wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay
end end
function alifLearn(wRec, function alifLearn(wRec,
@@ -468,46 +377,37 @@ function alifLearn(wRec,
progress, progress,
device) device)
# merge learning weight with average learning weight of all batch # transfer data to cpu
wch = sum(wRecChange, dims=4) ./ (size(wRec, 4)) .* arrayProjection4d
wRec .= (exInType .* wRec) .+ wch
arrayProjection4d_cpu = arrayProjection4d |> cpu arrayProjection4d_cpu = arrayProjection4d |> cpu
wRec_cpu = wRec |> cpu wRec_cpu = wRec |> cpu
wRecChange_cpu = wRecChange |> cpu
eta_cpu = eta |> cpu eta_cpu = eta |> cpu
exInType_cpu = exInType |> cpu
neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu
synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu
synapticActivityCounter_cpu = synapticActivityCounter |> cpu
zitCumulative_cpu = zitCumulative |> cpu zitCumulative_cpu = zitCumulative |> cpu
# -W if less than 10% of repeat avg, +W otherwise
_, _, i3 = size(wRec_cpu)
for i in 1:i3
x = 0.1 * (sum(synapseReconnectDelay[:,:,i]) / length(synapseReconnectDelay[:,:,i]))
mask = GeneralUtils.replaceLessThan.(wRec_cpu[:,:,i], x, -1, 1)
wRec_cpu[:,:,i] .+= mask .* eta_cpu[:,:,i] .* wRec_cpu[:,:,i]
end
# weak / negative synaptic connection will get randomed in neuroplasticity()
wRec_cpu = GeneralUtils.replaceBetween.(wRec_cpu, 0.0, 0.01, -1.0) # mark with -1.0
# neuroplasticity, work on CPU side # neuroplasticity, work on CPU side
wRec_cpu = neuroplasticity(synapseConnectionNumber, wRec_cpu, neuronInactivityCounter_cpu, synapticActivityCounter_cpu, synapseReconnectDelay_cpu =
neuroplasticity(synapseConnectionNumber,
zitCumulative_cpu, zitCumulative_cpu,
wRec_cpu, wRec_cpu,
exInType_cpu,
wRecChange_cpu,
vt,
eta_cpu,
neuronInactivityCounter_cpu, neuronInactivityCounter_cpu,
synapseReconnectDelay_cpu) synapseReconnectDelay_cpu,
synapticActivityCounter_cpu,
progress,)
wRec_cpu = wRec_cpu .* arrayProjection4d_cpu # transfer data backto gpu
wRec = wRec_cpu |> device wRec = wRec_cpu |> device
neuronInactivityCounter_cpu = neuronInactivityCounter_cpu .* arrayProjection4d_cpu
neuronInactivityCounter = neuronInactivityCounter_cpu |> device neuronInactivityCounter = neuronInactivityCounter_cpu |> device
synapticActivityCounter = synapticActivityCounter_cpu |> device
synapseReconnectDelay_cpu = synapseReconnectDelay_cpu .* arrayProjection4d_cpu
synapseReconnectDelay = synapseReconnectDelay_cpu |> device synapseReconnectDelay = synapseReconnectDelay_cpu |> device
# error("DEBUG -> alifLearn! $(Dates.now())") # error("DEBUG -> alifLearn! $(Dates.now())")
return wRec, neuronInactivityCounter, synapseReconnectDelay return wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay
end end
function onLearn!(wOut, function onLearn!(wOut,
@@ -522,7 +422,7 @@ end
GeneralUtils.allTrue(args...) = false [args...] ? false : true GeneralUtils.allTrue(args...) = false [args...] ? false : true
#WORKING 2) rewrite this function #WORKING
function neuroplasticity(synapseConnectionNumber, function neuroplasticity(synapseConnectionNumber,
zitCumulative, # (row, col) zitCumulative, # (row, col)
wRec, # (row, col, n) wRec, # (row, col, n)
@@ -534,12 +434,6 @@ function neuroplasticity(synapseConnectionNumber,
synapseReconnectDelay, synapseReconnectDelay,
synapticActivityCounter, synapticActivityCounter,
progress,) # (row, col, n) progress,) # (row, col, n)
i1,i2,i3 = size(wRec)
println("eta $(size(eta))")
println("wRec 1 $(size(wRec)) ", wRec[:,:,1,1])
println("zitCumulative $(size(zitCumulative))")
println("progress $progress")
if progress == 2 # no need to learn if progress == 2 # no need to learn
# skip neuroplasticity # skip neuroplasticity
@@ -584,7 +478,8 @@ function neuroplasticity(synapseConnectionNumber,
end end
# error("DEBUG -> neuroplasticity $(Dates.now())") # error("DEBUG -> neuroplasticity $(Dates.now())")
return wRec return wRec, neuronInactivityCounter,
synapticActivityCounter, synapseReconnectDelay
end end
# learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5 # learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5

View File

@@ -120,21 +120,16 @@ function growRepeatedPath!(wRec, synapticActivityCounter, eta)
end end
function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed) function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed)
# println("wRec ", wRec[:,:,1,1])
mask_inactiveSynapse = isequal.(synapticActivityCounter, 0) mask_inactiveSynapse = isequal.(synapticActivityCounter, 0)
mask_notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1) # 2nd criteria, not mature synapse has weight < 0.1 mask_notmature = GeneralUtils.isBetween.(wRec, 0.0, 0.1) # 2nd criteria, not mature synapse has weight < 0.1
# println("mask_notmature ", mask_notmature[:,:,1,1])
mask_1 = GeneralUtils.allTrue.(mask_inactiveSynapse, mask_notmature) mask_1 = GeneralUtils.allTrue.(mask_inactiveSynapse, mask_notmature)
# println("mask_1 ", mask_1[:,:,1,1])
mask_2 = mask_1 .* (1 .- eta) mask_2 = mask_1 .* (1 .- eta)
GeneralUtils.replaceElements!(mask_2, 0, 1) # replace 0 with 1 so mask * Wrec will not get 0 weight GeneralUtils.replaceElements!(mask_2, 0, 1) # replace 0 with 1 so mask * Wrec will not get 0 weight
wRec .*= mask_2 wRec .*= mask_2
# println("wRec 2 ", wRec[:,:,1,1])
end end
function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01)
println("weak synapse ", sum(mask_weak))
mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01)
wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned
# all weak synapse activity are reset # all weak synapse activity are reset
@@ -165,22 +160,23 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr
else else
for ind in eachindex(synapseReconnectDelay[:,:,n,i4]) for ind in eachindex(synapseReconnectDelay[:,:,n,i4])
timemark = synapseReconnectDelay[:,:,n,i4][ind] timemark = synapseReconnectDelay[:,:,n,i4][ind]
# println("timemark ", timemark)
if timemark > 0 #TODO not fully tested. mark timeStep available if timemark > 0 #TODO not fully tested. mark timeStep available
timemark = Int(timemark)
# get neuron pool at 10 timeStep earlier # get neuron pool at 10 timeStep earlier
earlier = timemark - 10 > 0 ? timemark - 10 : timemark earlier = timemark - 10 > 0 ? timemark - 10 : timemark
pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) #BUG BoundsError: attempt to access 10×25×801 Array{Float32, 3} at index [1:10, 1:25, 1340.0f0:1.0f0:1350.0f0]
if sum(pool) != 0 if sum(pool) != 0
indices = findall(x -> x != 0, pool) indices = findall(x -> x != 0, pool)
pick = rand(indices) pick = rand(indices) # cartesian indice
wRec[pick] = rand(0.01:0.01:0.05) wRec[pick] = rand(0.01:0.01:0.05)
synapticActivityCounter[pick] = 0 synapticActivityCounter[pick] = 0
synapseReconnectDelay[pick] = -0.1 synapseReconnectDelay[pick] = -0.1
error("DEBUG -> rewireSynapse!") # error("DEBUG -> rewireSynapse!")
else # if neurons not firing at all, try again next time else # if neurons not firing at all, try again next time
synapticActivityCounter[pick] = 0 synapticActivityCounter[:,:,n,i4][ind] = 0
synapseReconnectDelay[:,:,n,i4] = rand(1:1000) synapseReconnectDelay[:,:,n,i4][ind] = rand(1:1000) * -1
error("DEBUG -> rewireSynapse!") # error("DEBUG -> rewireSynapse!")
end end
end end
end end