dev
This commit is contained in:
235
src/forward.jl
235
src/forward.jl
@@ -675,241 +675,6 @@ function onForward( zit,
|
||||
return nothing
|
||||
end
|
||||
|
||||
# function lifForward(kfn_zit::Array{T},
|
||||
# zit::Array{T},
|
||||
# wRec::Array{T},
|
||||
# vt0::Array{T},
|
||||
# vt1::Array{T},
|
||||
# vth::Array{T},
|
||||
# vRest::Array{T},
|
||||
# zt1::Array{T},
|
||||
# alpha::Array{T},
|
||||
# phi::Array{T},
|
||||
# epsilonRec::Array{T},
|
||||
# refractoryCounter::Array{T},
|
||||
# refractoryDuration::Array{T},
|
||||
# gammaPd::Array{T},
|
||||
# firingCounter::Array{T},
|
||||
# arrayProjection4d::Array{T},
|
||||
# recSignal::Array{T},
|
||||
# decayed_vt0::Array{T},
|
||||
# decayed_epsilonRec::Array{T},
|
||||
# vt1_diff_vth::Array{T},
|
||||
# vt1_diff_vth_div_vth::Array{T},
|
||||
# gammaPd_div_vth::Array{T},
|
||||
# phiActivation::Array{T},
|
||||
# ) where T<:Number
|
||||
|
||||
# # project 3D kfn zit into 4D lif zit
|
||||
# i1, i2, i3, i4 = size(alif_wRec)
|
||||
# lif_zit .= reshape(kfn_zit, (i1, i2, 1, i4)) .* lif_arrayProjection4d
|
||||
|
||||
# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
||||
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
||||
# @. @views refractoryCounter[:,:,i,j] -= 1
|
||||
# @. @views zt1[:,:,i,j] = 0
|
||||
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
||||
# @. @views phi[:,:,i,j] = 0
|
||||
|
||||
# # compute epsilonRec
|
||||
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
|
||||
# else # refractory period is inactive
|
||||
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j]
|
||||
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
||||
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
|
||||
|
||||
# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j]))
|
||||
# @. @views zt1[:,:,i,j] = 1
|
||||
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
|
||||
# @. @views firingCounter[:,:,i,j] += 1
|
||||
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
|
||||
# else
|
||||
# @. @views zt1[:,:,i,j] = 0
|
||||
# end
|
||||
|
||||
# # compute phi, there is a difference from alif formula
|
||||
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
|
||||
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
|
||||
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
|
||||
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
|
||||
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
|
||||
|
||||
# # compute epsilonRec
|
||||
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
|
||||
# function alifForward(zit::Array{T},
|
||||
# wRec::Array{T},
|
||||
# vt0::Array{T},
|
||||
# vt1::Array{T},
|
||||
# vth::Array{T},
|
||||
# vRest::Array{T},
|
||||
# zt1::Array{T},
|
||||
# alpha::Array{T},
|
||||
# phi::Array{T},
|
||||
# epsilonRec::Array{T},
|
||||
# refractoryCounter::Array{T},
|
||||
# refractoryDuration::Array{T},
|
||||
# gammaPd::Array{T},
|
||||
# firingCounter::Array{T},
|
||||
# recSignal::Array{T},
|
||||
# decayed_vt0::Array{T},
|
||||
# decayed_epsilonRec::Array{T},
|
||||
# vt1_diff_vth::Array{T},
|
||||
# vt1_diff_vth_div_vth::Array{T},
|
||||
# gammaPd_div_vth::Array{T},
|
||||
# phiActivation::Array{T},
|
||||
|
||||
# epsilonRecA::Array{T},
|
||||
# avth::Array{T},
|
||||
# a::Array{T},
|
||||
# beta::Array{T},
|
||||
# rho::Array{T},
|
||||
# phi_x_epsilonRec::Array{T},
|
||||
# phi_x_beta::Array{T},
|
||||
# rho_diff_phi_x_beta::Array{T},
|
||||
# rho_div_phi_x_beta_x_epsilonRecA::Array{T},
|
||||
# beta_x_a::Array{T},
|
||||
# ) where T<:Number
|
||||
|
||||
# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
||||
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
||||
# @. @views refractoryCounter[:,:,i,j] -= 1
|
||||
# @. @views zt1[:,:,i,j] = 0
|
||||
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
||||
# @. @views phi[:,:,i,j] = 0
|
||||
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
|
||||
|
||||
# # compute epsilonRec
|
||||
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
|
||||
|
||||
# # compute epsilonRecA
|
||||
# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j]
|
||||
# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j]
|
||||
# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j]
|
||||
# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j]
|
||||
|
||||
# # compute avth
|
||||
# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j]
|
||||
# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j]
|
||||
|
||||
# else # refractory period is inactive
|
||||
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j]
|
||||
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
||||
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
|
||||
|
||||
# # compute avth
|
||||
# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j]
|
||||
# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j]
|
||||
|
||||
# if sum(@view(vt1[:,:,i,j])) > sum(@view(avth[:,:,i,j]))
|
||||
# @. @views zt1[:,:,i,j] = 1
|
||||
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
|
||||
# @. @views firingCounter[:,:,i,j] += 1
|
||||
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
|
||||
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
|
||||
# @. @views a[:,:,i,j] = a[:,:,i,j] += 1
|
||||
# else
|
||||
# @. @views zt1[:,:,i,j] = 0
|
||||
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
|
||||
# end
|
||||
|
||||
# # compute phi, there is a difference from alif formula
|
||||
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
|
||||
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
|
||||
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
|
||||
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
|
||||
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
|
||||
|
||||
# # compute epsilonRec
|
||||
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
|
||||
|
||||
# # compute epsilonRecA
|
||||
# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j]
|
||||
# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j]
|
||||
# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j]
|
||||
# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j]
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
|
||||
# function onForward(kfn_zit::Array{T},
|
||||
# zit::Array{T},
|
||||
# wOut::Array{T},
|
||||
# vt0::Array{T},
|
||||
# vt1::Array{T},
|
||||
# vth::Array{T},
|
||||
# vRest::Array{T},
|
||||
# zt1::Array{T},
|
||||
# alpha::Array{T},
|
||||
# phi::Array{T},
|
||||
# epsilonRec::Array{T},
|
||||
# refractoryCounter::Array{T},
|
||||
# refractoryDuration::Array{T},
|
||||
# gammaPd::Array{T},
|
||||
# firingCounter::Array{T},
|
||||
# arrayProjection4d::Array{T},
|
||||
# recSignal::Array{T},
|
||||
# decayed_vt0::Array{T},
|
||||
# decayed_epsilonRec::Array{T},
|
||||
# vt1_diff_vth::Array{T},
|
||||
# vt1_diff_vth_div_vth::Array{T},
|
||||
# gammaPd_div_vth::Array{T},
|
||||
# phiActivation::Array{T},
|
||||
# ) where T<:Number
|
||||
|
||||
# # project 3D kfn zit into 4D lif zit
|
||||
# zit .= reshape(kfn_zit,
|
||||
# (size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection4d
|
||||
|
||||
# for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch
|
||||
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
||||
# @. @views refractoryCounter[:,:,i,j] -= 1
|
||||
# @. @views zt1[:,:,i,j] = 0
|
||||
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
||||
# @. @views phi[:,:,i,j] = 0
|
||||
|
||||
# # compute epsilonRec
|
||||
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
|
||||
# else # refractory period is inactive
|
||||
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wOut[:,:,i,j]
|
||||
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
||||
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
|
||||
|
||||
# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j]))
|
||||
# @. @views zt1[:,:,i,j] = 1
|
||||
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
|
||||
# @. @views firingCounter[:,:,i,j] += 1
|
||||
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
|
||||
# else
|
||||
# @. @views zt1[:,:,i,j] = 0
|
||||
# end
|
||||
|
||||
# # compute phi, there is a difference from alif formula
|
||||
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
|
||||
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
|
||||
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
|
||||
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
|
||||
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
|
||||
|
||||
# # compute epsilonRec
|
||||
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
||||
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
57
src/learn.jl
57
src/learn.jl
@@ -283,16 +283,15 @@ end
|
||||
|
||||
function learn!(kfn::kfn_1, progress, device=cpu)
|
||||
if sum(kfn.timeStep) == 800
|
||||
println("zitCumulative ", sum(kfn.zitCumulative[:,:,784:size(kfn.zitCumulative, 3)], dims=3))
|
||||
# println("zitCumulative ", sum(kfn.zitCumulative[:,:,784:size(kfn.zitCumulative, 3)], dims=3))
|
||||
println("synapse lif $(sum((!isequal).(kfn.lif_wRec, 0))) alif $(sum((!isequal).(kfn.alif_wRec, 0)))")
|
||||
println("on_synapticActivityCounter 0 ", kfn.on_synapticActivityCounter[:,:,1])
|
||||
println("on_synapticActivityCounter 5 ", kfn.on_synapticActivityCounter[:,:,6])
|
||||
println("wOut 0 ", sum(kfn.on_wOut[:,:,1,1], dims=3))
|
||||
println("wOut 5 ", sum(kfn.on_wOut[:,:,6,1], dims=3))
|
||||
|
||||
println("wOut 0 $(sum(kfn.on_wOut[:,:,1,1], dims=3)) total $(sum(sum(kfn.on_wOut[:,:,1,1], dims=3)))")
|
||||
println("wOut 5 $(sum(kfn.on_wOut[:,:,6,1], dims=3)) total $(sum(sum(kfn.on_wOut[:,:,6,1], dims=3)))")
|
||||
end
|
||||
|
||||
#WORKING compare output neuron 0 synapse activity when input are label 0 and 5, (!isequal).(wOut)
|
||||
|
||||
|
||||
# lif learn
|
||||
kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapticActivityCounter, kfn.lif_synapseReconnectDelay =
|
||||
lifLearn(kfn.lif_wRec,
|
||||
@@ -451,11 +450,12 @@ function onLearn!(wOut,
|
||||
|
||||
|
||||
if progress != 0
|
||||
# adaptive wOut to help convergence using c_decay
|
||||
wOut .-= 0.1 .* eta .* wOut # wOut .-= 0.001 .* wOut
|
||||
|
||||
# merge learning weight with average learning weight
|
||||
wOut .+= (sum(wOutChange, dims=4) ./ (size(wOut, 4))) .* arrayProjection4d
|
||||
|
||||
# adaptive wOut to help convergence using c_decay
|
||||
wOut .-= 0.1 .* eta .* wOut # wOut .-= 0.001 .* wOut
|
||||
else
|
||||
#TESTING skip
|
||||
wOutChange .= 0
|
||||
@@ -478,23 +478,32 @@ function neuroplasticity(synapseConnectionNumber,
|
||||
# skip neuroplasticity
|
||||
#TODO I may need to do something with neuronInactivityCounter and other variables
|
||||
wRecChange .= 0
|
||||
|
||||
# # -w all non-fire connection except mature connection
|
||||
# weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# # prune weak synapse
|
||||
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
|
||||
# error("DEBUG -> neuroplasticity")
|
||||
elseif progress == 1 # some progress whether up or down
|
||||
# ready to reconnect synapse must not have wRecChange
|
||||
mask = (!isequal).(wRec, 0)
|
||||
wRecChange .*= mask
|
||||
|
||||
# weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# merge learning weight, all resulting negative wRec will get pruned
|
||||
mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter, synapseReconnectDelay)
|
||||
|
||||
# adjust wRec based on repeatition (90% +w, 10% -w)
|
||||
growRepeatedPath!(wRec, synapticActivityCounter, eta)
|
||||
# # adjust wRec based on repeatition (90% +w, 10% -w)
|
||||
# growRepeatedPath!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# -w all non-fire connection except mature connection
|
||||
weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
|
||||
# # -w all non-fire connection except mature connection
|
||||
# weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# prune weak synapse
|
||||
pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
# # prune weak synapse
|
||||
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
|
||||
# rewire synapse connection
|
||||
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
|
||||
@@ -503,8 +512,10 @@ function neuroplasticity(synapseConnectionNumber,
|
||||
elseif progress == 0 # no progress, no weight update, only rewire
|
||||
wRecChange .= 0
|
||||
|
||||
# prune weak synapse
|
||||
pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
# weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# # prune weak synapse
|
||||
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
|
||||
# rewire synapse connection
|
||||
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
|
||||
@@ -515,18 +526,20 @@ function neuroplasticity(synapseConnectionNumber,
|
||||
# ready to reconnect synapse must not have wRecChange
|
||||
mask = (!isequal).(wRec, 0)
|
||||
wRecChange .*= mask
|
||||
|
||||
# weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# merge learning weight, all resulting negative wRec will get pruned
|
||||
mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter, synapseReconnectDelay)
|
||||
|
||||
# adjust wRec based on repeatition (90% +w, 10% -w)
|
||||
growRepeatedPath!(wRec, synapticActivityCounter, eta)
|
||||
# # adjust wRec based on repeatition (90% +w, 10% -w)
|
||||
# growRepeatedPath!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# -w all non-fire connection except mature connection
|
||||
weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
|
||||
# # -w all non-fire connection except mature connection
|
||||
# weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
|
||||
|
||||
# prune weak synapse
|
||||
pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
# # prune weak synapse
|
||||
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
|
||||
# rewire synapse connection
|
||||
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
module snnUtil
|
||||
|
||||
export refractoryStatus!, addNewSynapticConn!, mergeLearnWeight!, growRepeatedPath!,
|
||||
weakenNotMatureSynapse!, pruneSynapse!, rewireSynapse!
|
||||
weakenNotMatureSynapse!, pruneSynapse!, rewireSynapse!, weakenAllActiveSynapse!
|
||||
|
||||
using Random, GeneralUtils
|
||||
using ..type
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
synapseMaxWaittime = 100
|
||||
|
||||
function refractoryStatus!(refractoryCounter, refractoryActive, refractoryInactive)
|
||||
d1, d2, d3, d4 = size(refractoryCounter)
|
||||
@@ -82,7 +83,7 @@ function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractAr
|
||||
# println("wRec 5 $(size(wRec)) ", wRec[:,:,1,1])
|
||||
GeneralUtils.replaceElements!(flipsign, 1, synapticActivityCounter, 0)
|
||||
# set pruned synapse to random wait time
|
||||
waittime = rand((1:1000), size(wRec)) .* flipsign # synapse's random wait time to reconnect
|
||||
waittime = rand((1:synapseMaxWaittime), size(wRec)) .* flipsign # synapse's random wait time to reconnect
|
||||
# synapseReconnectDelay counting mode when value is negative hence .* -1
|
||||
synapseReconnectDelay .= (synapseReconnectDelay .* nonflipsign) .+ (waittime .* -1)
|
||||
# println("synapseReconnectDelay ", synapseReconnectDelay[:,:,1,1])
|
||||
@@ -113,6 +114,12 @@ function growRepeatedPath!(wRec, synapticActivityCounter, eta)
|
||||
# error("DEBUG -> growRepeatedPath!")
|
||||
end
|
||||
|
||||
function weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed)
|
||||
mask_activeSynapse = (!isequal).(synapticActivityCounter, 0)
|
||||
mask_1 = mask_activeSynapse .* (1 .- (0.1 .* eta))
|
||||
GeneralUtils.replaceElements!(mask_1, 0, 1) # replace 0 with 1 so mask * Wrec will not get 0 weight
|
||||
wRec .*= mask_1
|
||||
end
|
||||
|
||||
function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed)
|
||||
mask_inactiveSynapse = isequal.(synapticActivityCounter, 0)
|
||||
@@ -130,7 +137,7 @@ function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
|
||||
# all weak synapse activity are reset
|
||||
GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, 0)
|
||||
# set pruned synapse to random wait time
|
||||
waittime = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect
|
||||
waittime = rand((1:synapseMaxWaittime), size(wRec)) .* mask_weak # synapse's random wait time to reconnect
|
||||
# synapseReconnectDelay counting mode when value is negative hence .* -1
|
||||
synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (waittime .* -1)
|
||||
# error("DEBUG -> pruneSynapse!")
|
||||
@@ -159,21 +166,21 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr
|
||||
|
||||
if timemark > 0 #TODO not fully tested. mark timeStep available
|
||||
timemark = Int(timemark)
|
||||
# get neuron pool at 10 timeStep earlier
|
||||
earlier = size(zitCumulative, 3) - 10 > 0 ? size(zitCumulative, 3) - 10 : size(zitCumulative, 3)
|
||||
# get neuron pool within 100 timeStep earlier
|
||||
earlier = size(zitCumulative, 3) - 100 > 0 ? size(zitCumulative, 3) - 100 : size(zitCumulative, 3)
|
||||
current = size(zitCumulative, 3)
|
||||
pool = sum(zitCumulative[:,:,earlier:current], dims=3)
|
||||
|
||||
if sum(pool) != 0
|
||||
indices = findall(x -> x != 0, pool)
|
||||
pick = rand(indices) # cartesian indice
|
||||
wRec[pick] = rand(0.01:0.01:0.05)
|
||||
wRec[pick] = rand(0.001:0.001:0.02)
|
||||
synapticActivityCounter[pick] = 0
|
||||
synapseReconnectDelay[pick] = -0.1
|
||||
# error("DEBUG -> rewireSynapse!")
|
||||
else # if neurons not firing at all, try again next time
|
||||
synapticActivityCounter[:,:,n,i4][ind] = 0
|
||||
synapseReconnectDelay[:,:,n,i4][ind] = rand(1:1000) * -1
|
||||
synapseReconnectDelay[:,:,n,i4][ind] = rand(1:synapseMaxWaittime) * -1 # wait time
|
||||
# error("DEBUG -> rewireSynapse!")
|
||||
end
|
||||
end
|
||||
|
||||
@@ -7,7 +7,7 @@ export
|
||||
# function
|
||||
random_wRec
|
||||
|
||||
using Random, GeneralUtils
|
||||
using Random, GeneralUtils, LinearAlgebra
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
rng = MersenneTwister(1234)
|
||||
@@ -372,7 +372,7 @@ function random_wRec(row, col, n, synapseConnectionNumber)
|
||||
for slice in eachslice(w, dims=3)
|
||||
pool = shuffle!([1:row*col...])[1:synapseConnectionNumber]
|
||||
for i in pool
|
||||
slice[i] = rand(0.01:0.01:0.05) # assign weight to synaptic connection. /10 to start small,
|
||||
slice[i] = rand() # assign weight to synaptic connection. /10 to start small,
|
||||
# otherwise RSNN's vt Usually stay negative (-)
|
||||
end
|
||||
end
|
||||
@@ -382,7 +382,7 @@ function random_wRec(row, col, n, synapseConnectionNumber)
|
||||
# avgWeight = sum(w)/length(w)
|
||||
# w = w .* (0.01 / avgWeight) # adjust overall weight
|
||||
|
||||
return w #(row, col, n)
|
||||
return normalize!(w) #(row, col, n)
|
||||
end
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user