rewrite neuroplasticity()

This commit is contained in:
ton
2023-09-17 13:57:54 +07:00
parent c8105a416e
commit 213f35b1ff
3 changed files with 13 additions and 143 deletions

View File

@@ -319,9 +319,8 @@ function lifForward( zit,
wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) * wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) *
zit[i1,i2,i3,i4] zit[i1,i2,i3,i4]
if synapseReconnectDelay[i1,i2,i3,i4] < 0 && # negative value is counting mode # negative value is counting mode, -0.1 < -0.1 won't work on GPU
synapseReconnectDelay[i1,i2,i3,i4] < -0.2 # -0.1 < -0.1 won't work on GPU if synapseReconnectDelay[i1,i2,i3,i4] < -0.2
synapseReconnectDelay[i1,i2,i3,i4] += 1 synapseReconnectDelay[i1,i2,i3,i4] += 1
if synapseReconnectDelay[i1,i2,i3,i4] == 0 if synapseReconnectDelay[i1,i2,i3,i4] == 0
# mark timestep # mark timestep
@@ -528,9 +527,8 @@ function alifForward( zit,
wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - avth[i1,i2,i3,i4]) * wRecChange[i1,i2,i3,i4] = -0.01*0.0001 * (vt[i1,i2,i3,i4] - avth[i1,i2,i3,i4]) *
zit[i1,i2,i3,i4] zit[i1,i2,i3,i4]
if synapseReconnectDelay[i1,i2,i3,i4] < 0 && # negative value is counting mode # negative value is counting mode, -0.1 < -0.1 won't work on GPU
synapseReconnectDelay[i1,i2,i3,i4] < -0.2 # -0.1 < -0.1 won't work on GPU if synapseReconnectDelay[i1,i2,i3,i4] < -0.2
synapseReconnectDelay[i1,i2,i3,i4] += 1 synapseReconnectDelay[i1,i2,i3,i4] += 1
if synapseReconnectDelay[i1,i2,i3,i4] == 0 if synapseReconnectDelay[i1,i2,i3,i4] == 0
# mark timestep # mark timestep

View File

@@ -545,7 +545,7 @@ function neuroplasticity(synapseConnectionNumber,
# skip neuroplasticity # skip neuroplasticity
#TODO I may need to do something with neuronInactivityCounter and other variables #TODO I may need to do something with neuronInactivityCounter and other variables
wRecChange .= 0 wRecChange .= 0
error("DEBUG -> neuroplasticity") # error("DEBUG -> neuroplasticity")
elseif progress != 0 # progress increase elseif progress != 0 # progress increase
# ready to reconnect synapse must not have wRecChange # ready to reconnect synapse must not have wRecChange
mask = (!isequal).(wRec, 0) mask = (!isequal).(wRec, 0)
@@ -566,8 +566,7 @@ function neuroplasticity(synapseConnectionNumber,
# rewire synapse connection # rewire synapse connection
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
synapseReconnectDelay, zitCumulative) synapseReconnectDelay, zitCumulative)
error("DEBUG -> neuroplasticity") # error("DEBUG -> neuroplasticity 1")
elseif progress == 0 # no progress, no weight update, only rewire elseif progress == 0 # no progress, no weight update, only rewire
# -w all non-fire connection except mature connection # -w all non-fire connection except mature connection
weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
@@ -579,143 +578,15 @@ function neuroplasticity(synapseConnectionNumber,
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter, rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
synapseReconnectDelay, zitCumulative) synapseReconnectDelay, zitCumulative)
error("DEBUG -> neuroplasticity $(Dates.now())") # error("DEBUG -> neuroplasticity")
else else
error("undefined condition line $(@__LINE__)") error("undefined condition line $(@__LINE__)")
end end
# error("DEBUG -> neuroplasticity $(Dates.now())")
# # for each neuron, find total number of synaptic conn that should draw
# # new connection to firing and non-firing neurons pool
# subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber))
# # for each neuron, count how many synapse already subscribed to firing-neurons
# zw = zitCumulative .* wRec
# subToFireNeuron_current = sum(GeneralUtils.GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n)
# zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0
# projection = ones(i1,i2,i3)
# zitMask = zitMask .* projection # (row, col, n)
# totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n)
# println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced")
# # clear -1.0 marker
# GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99)
# GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required
# for i in 1:i3
# if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight
# println("neuron die")
# neuronInactivityCounter[:,:,i] .= 0 # reset
# w = random_wRec(i1,i2,1,synapseConnectionNumber)
# wRec[:,:,i] .= w
# a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron
# mask = (!iszero).(w)
# GeneralUtils.replaceElements!(mask, 1, a, 0)
# synapseReconnectDelay[:,:,i] = a
# else
# remaining = 0
# if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe
# toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i]
# totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn
# # add new conn to firing neurons pool
# remaining = addNewSynapticConn!(zitMask[:,:,i], 1,
# @view(wRec[:,:,i]),
# @view(synapseReconnectDelay[:,:,i]),
# toAddConn)
# totalNewConn[1,1,i] += remaining
# end
# # add new conn to non-firing neurons pool
# remaining = addNewSynapticConn!(zitMask[:,:,i], 0,
# @view(wRec[:,:,i]),
# @view(synapseReconnectDelay[:,:,i]),
# totalNewConn[1,1,i])
# if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot
# remaining = addNewSynapticConn!(zitMask[:,:,i], 1,
# @view(wRec[:,:,i]),
# @view(synapseReconnectDelay[:,:,i]),
# remaining)
# end
# end
# end
# error("DEBUG -> neuroplasticity $(Dates.now())") # error("DEBUG -> neuroplasticity $(Dates.now())")
return wRec return wRec
end end
# function neuroplasticity(synapseConnectionNumber,
# zitCumulative, # (row, col)
# wRec, # (row, col, n)
# neuronInactivityCounter,
# synapseReconnectDelay) # (row, col, n)
# i1,i2,i3 = size(wRec)
# # for each neuron, find total number of synaptic conn that should draw
# # new connection to firing and non-firing neurons pool
# subToFireNeuron_toBe = Int(floor(0.7 * synapseConnectionNumber))
# # for each neuron, count how many synap already subscribed to firing-neurons
# zw = zitCumulative .* wRec
# subToFireNeuron_current = sum(GeneralUtils.GeneralUtils.isBetween.(zw, 0.0, 100.0), dims=(1,2)) # (1, 1, n)
# zitMask = (!iszero).(zitCumulative) # zitMask of firing neurons = 1, non-firing = 0
# projection = ones(i1,i2,i3)
# zitMask = zitMask .* projection # (row, col, n)
# totalNewConn = sum(isequal.(wRec, -1.0), dims=(1,2)) # count new conn mark (-1.0), (1, 1, n)
# println("neuroplasticity, from $(synapseConnectionNumber*size(totalNewConn, 3)) conn, $(sum(totalNewConn)) are replaced")
# # clear -1.0 marker
# GeneralUtils.replaceElements!(wRec, -1.0, synapseReconnectDelay, -0.99)
# GeneralUtils.replaceElements!(wRec, -1.0, 0.0) # -1.0 marker is no longer required
# for i in 1:i3
# if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight
# println("neuron die")
# neuronInactivityCounter[:,:,i] .= 0 # reset
# w = random_wRec(i1,i2,1,synapseConnectionNumber)
# wRec[:,:,i] .= w
# a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron
# mask = (!iszero).(w)
# GeneralUtils.replaceElements!(mask, 1, a, 0)
# synapseReconnectDelay[:,:,i] = a
# else
# remaining = 0
# if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe
# toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i]
# totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn
# # add new conn to firing neurons pool
# remaining = addNewSynapticConn!(zitMask[:,:,i], 1,
# @view(wRec[:,:,i]),
# @view(synapseReconnectDelay[:,:,i]),
# toAddConn)
# totalNewConn[1,1,i] += remaining
# end
# # add new conn to non-firing neurons pool
# remaining = addNewSynapticConn!(zitMask[:,:,i], 0,
# @view(wRec[:,:,i]),
# @view(synapseReconnectDelay[:,:,i]),
# totalNewConn[1,1,i])
# if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot
# remaining = addNewSynapticConn!(zitMask[:,:,i], 1,
# @view(wRec[:,:,i]),
# @view(synapseReconnectDelay[:,:,i]),
# remaining)
# end
# end
# end
# # error("DEBUG -> neuroplasticity $(Dates.now())")
# return wRec
# end
# learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5 # learningLiquidity(x) = -0.0001x + 1 # -10000 to +10000; f(x) = -5e-05x+0.5
function learningLiquidity(x) function learningLiquidity(x)

View File

@@ -88,7 +88,7 @@ function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractAr
# error("DEBUG -> mergeLearnWeight!") # error("DEBUG -> mergeLearnWeight!")
end end
function growRepeatedPath!(wRec, synapticActivityCounter, eta) #BUG wRec get all 0 function growRepeatedPath!(wRec, synapticActivityCounter, eta)
# seperate active synapse out of inactive in this signal # seperate active synapse out of inactive in this signal
mask_activeSynapse = (!isequal).(synapticActivityCounter, 0) mask_activeSynapse = (!isequal).(synapticActivityCounter, 0)
# println("synapticActivityCounter ", synapticActivityCounter[:,:,1,1]) # println("synapticActivityCounter ", synapticActivityCounter[:,:,1,1])
@@ -134,6 +134,7 @@ end
function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay) function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01) mask_weak = GeneralUtils.isBetween.(wRec, 0.0, 0.01)
println("weak synapse ", sum(mask_weak))
mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01) mask_notweak = (!GeneralUtils.isBetween).(wRec, 0.0, 0.01)
wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned
# all weak synapse activity are reset # all weak synapse activity are reset
@@ -164,23 +165,23 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr
else else
for ind in eachindex(synapseReconnectDelay[:,:,n,i4]) for ind in eachindex(synapseReconnectDelay[:,:,n,i4])
timemark = synapseReconnectDelay[:,:,n,i4][ind] timemark = synapseReconnectDelay[:,:,n,i4][ind]
if timemark > 0 # mark timeStep available # println("timemark ", timemark)
if timemark > 0 #TODO not fully tested. mark timeStep available
# get neuron pool at 10 timeStep earlier # get neuron pool at 10 timeStep earlier
earlier = timemark - 10 > 0 ? timemark - 10 : timemark earlier = timemark - 10 > 0 ? timemark - 10 : timemark
pool = sum(zitCumulative[:,:,earlier:timemark], dims=3) pool = sum(zitCumulative[:,:,earlier:timemark], dims=3)
println(">>> ", sum(pool))
if sum(pool) != 0 if sum(pool) != 0
indices = findall(x -> x != 0, pool) indices = findall(x -> x != 0, pool)
pick = rand(indices) pick = rand(indices)
wRec[pick] = rand(0.01:0.01:0.05) wRec[pick] = rand(0.01:0.01:0.05)
synapticActivityCounter[pick] = 0 synapticActivityCounter[pick] = 0
synapseReconnectDelay[pick] = -0.1 synapseReconnectDelay[pick] = -0.1
error("DEBUG -> rewireSynapse!")
else # if neurons not firing at all, try again next time else # if neurons not firing at all, try again next time
synapticActivityCounter[pick] = 0 synapticActivityCounter[pick] = 0
synapseReconnectDelay[:,:,n,i4] = rand(1:1000) synapseReconnectDelay[:,:,n,i4] = rand(1:1000)
error("DEBUG -> rewireSynapse!")
end end
error("DEBUG -> rewireSynapse!")
end end
end end
end end