dev
This commit is contained in:
58
src/learn.jl
58
src/learn.jl
@@ -591,6 +591,10 @@ function neuroplasticity(synapseConnectionNumber,
|
||||
#TODO I may need to do something with neuronInactivityCounter and other variables
|
||||
wRecChange .= 0
|
||||
elseif progress == 1 # progress increase
|
||||
# ready to reconnect synapse must not have wRecChange
|
||||
mask = (!isequal).(wRec, 0)
|
||||
wRecChange .*= mask
|
||||
|
||||
# merge learning weight with average learning weight of all batch
|
||||
wRec .= abs.((exInType .* wRec) .+ wRecChange) # abs because wRec doesn't carry sign
|
||||
|
||||
@@ -609,7 +613,7 @@ function neuroplasticity(synapseConnectionNumber,
|
||||
wRec .*= mask_2
|
||||
|
||||
# -w, synapse with less than 10% of avg activity get reduced weight by eta
|
||||
mask_less = isless.(synapticActivityCounter, lowerlimit) # 1st criteria
|
||||
mask_less = isbetween.(synapticActivityCounter, 0.0, lowerlimit) # 1st criteria
|
||||
|
||||
mask_3 = alltrue.(mask_activeSynapse, mask_less)
|
||||
mask_3 .*= 1 .- eta # minor activity synapse weight will be reduced by eta
|
||||
@@ -621,12 +625,56 @@ function neuroplasticity(synapseConnectionNumber,
|
||||
mask_1 .*= 1 .- eta
|
||||
wRec .*= mask_1
|
||||
|
||||
# prune weak connection
|
||||
# mark weak / negative synaptic connection so they will get randomed in neuroplasticity()
|
||||
mask = isbetween.(wRec, 0.0, 0.01)
|
||||
wRec = GeneralUtils.replaceBetween.(wRec, 0.0, 0.01, -1.0) # mark with -1.0
|
||||
# prune synapse
|
||||
mask_weak = isbetween.(wRec, 0.0, 0.01)
|
||||
mask_notweak = (!isbetween).(wRec, 0.0, 0.01)
|
||||
wRec .*= mask_notweak # all marked weak synapse weight need to be 0.0 i.e. pruned
|
||||
r = rand((1:1000), size(wRec)) .* mask_weak # synapse random wait time to reconnect
|
||||
synapticActivityCounter .*= mask_notweak # all marked weak synapse activity are reset
|
||||
synapticActivityCounter .+= (mask_weak .* -1.0)
|
||||
synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ r # set pruned synapse to random wait time
|
||||
|
||||
#WORKING rewire synapse connection
|
||||
synapseReconnectDelay mark timeStep while also counting delay == BUG
|
||||
|
||||
for i in 1:i3 # neuron-by-neuron
|
||||
if neuronInactivityCounter[1:1:i][1] < -10000 # neuron die i.e. reset all weight
|
||||
println("neuron die")
|
||||
neuronInactivityCounter[:,:,i] .= 0 # reset
|
||||
w = random_wRec(i1,i2,1,synapseConnectionNumber)
|
||||
wRec[:,:,i] .= w
|
||||
|
||||
#WORKING
|
||||
a = similar(w) .= -0.99 # synapseConnectionNumber of this neuron
|
||||
mask = (!iszero).(w)
|
||||
GeneralUtils.replaceElements!(mask, 1, a, 0)
|
||||
synapseReconnectDelay[:,:,i] = a
|
||||
else
|
||||
remaining = 0
|
||||
if subToFireNeuron_current[1,1,i] < subToFireNeuron_toBe
|
||||
toAddConn = subToFireNeuron_toBe - subToFireNeuron_current[1,1,i]
|
||||
totalNewConn[1,1,i] = totalNewConn[1,1,i] - toAddConn
|
||||
# add new conn to firing neurons pool
|
||||
remaining = addNewSynapticConn!(zitMask[:,:,i], 1,
|
||||
@view(wRec[:,:,i]),
|
||||
@view(synapseReconnectDelay[:,:,i]),
|
||||
toAddConn)
|
||||
totalNewConn[1,1,i] += remaining
|
||||
end
|
||||
|
||||
# add new conn to non-firing neurons pool
|
||||
remaining = addNewSynapticConn!(zitMask[:,:,i], 0,
|
||||
@view(wRec[:,:,i]),
|
||||
@view(synapseReconnectDelay[:,:,i]),
|
||||
totalNewConn[1,1,i])
|
||||
if remaining > 0 # final get-all round if somehow non-firing pool has not enough slot
|
||||
remaining = addNewSynapticConn!(zitMask[:,:,i], 1,
|
||||
@view(wRec[:,:,i]),
|
||||
@view(synapseReconnectDelay[:,:,i]),
|
||||
remaining)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
elseif progress == 0 # no progress, no weight update, only rewire
|
||||
|
||||
Reference in New Issue
Block a user