This commit is contained in:
ton
2023-09-29 08:27:26 +07:00
parent 2b15510669
commit 364cfb4ea8
18 changed files with 7364 additions and 267 deletions

View File

@@ -675,241 +675,6 @@ function onForward( zit,
return nothing
end
# function lifForward(kfn_zit::Array{T},
# zit::Array{T},
# wRec::Array{T},
# vt0::Array{T},
# vt1::Array{T},
# vth::Array{T},
# vRest::Array{T},
# zt1::Array{T},
# alpha::Array{T},
# phi::Array{T},
# epsilonRec::Array{T},
# refractoryCounter::Array{T},
# refractoryDuration::Array{T},
# gammaPd::Array{T},
# firingCounter::Array{T},
# arrayProjection4d::Array{T},
# recSignal::Array{T},
# decayed_vt0::Array{T},
# decayed_epsilonRec::Array{T},
# vt1_diff_vth::Array{T},
# vt1_diff_vth_div_vth::Array{T},
# gammaPd_div_vth::Array{T},
# phiActivation::Array{T},
# ) where T<:Number
# # project 3D kfn zit into 4D lif zit
# i1, i2, i3, i4 = size(alif_wRec)
# lif_zit .= reshape(kfn_zit, (i1, i2, 1, i4)) .* lif_arrayProjection4d
# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
# @. @views refractoryCounter[:,:,i,j] -= 1
# @. @views zt1[:,:,i,j] = 0
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
# @. @views phi[:,:,i,j] = 0
# # compute epsilonRec
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
# else # refractory period is inactive
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j]
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j]))
# @. @views zt1[:,:,i,j] = 1
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
# @. @views firingCounter[:,:,i,j] += 1
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
# else
# @. @views zt1[:,:,i,j] = 0
# end
# # compute phi, there is a difference from alif formula
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
# # compute epsilonRec
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
# end
# end
# end
# function alifForward(zit::Array{T},
# wRec::Array{T},
# vt0::Array{T},
# vt1::Array{T},
# vth::Array{T},
# vRest::Array{T},
# zt1::Array{T},
# alpha::Array{T},
# phi::Array{T},
# epsilonRec::Array{T},
# refractoryCounter::Array{T},
# refractoryDuration::Array{T},
# gammaPd::Array{T},
# firingCounter::Array{T},
# recSignal::Array{T},
# decayed_vt0::Array{T},
# decayed_epsilonRec::Array{T},
# vt1_diff_vth::Array{T},
# vt1_diff_vth_div_vth::Array{T},
# gammaPd_div_vth::Array{T},
# phiActivation::Array{T},
# epsilonRecA::Array{T},
# avth::Array{T},
# a::Array{T},
# beta::Array{T},
# rho::Array{T},
# phi_x_epsilonRec::Array{T},
# phi_x_beta::Array{T},
# rho_diff_phi_x_beta::Array{T},
# rho_div_phi_x_beta_x_epsilonRecA::Array{T},
# beta_x_a::Array{T},
# ) where T<:Number
# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
# @. @views refractoryCounter[:,:,i,j] -= 1
# @. @views zt1[:,:,i,j] = 0
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
# @. @views phi[:,:,i,j] = 0
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
# # compute epsilonRec
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
# # compute epsilonRecA
# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j]
# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j]
# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j]
# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j]
# # compute avth
# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j]
# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j]
# else # refractory period is inactive
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j]
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
# # compute avth
# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j]
# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j]
# if sum(@view(vt1[:,:,i,j])) > sum(@view(avth[:,:,i,j]))
# @. @views zt1[:,:,i,j] = 1
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
# @. @views firingCounter[:,:,i,j] += 1
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
# @. @views a[:,:,i,j] = a[:,:,i,j] += 1
# else
# @. @views zt1[:,:,i,j] = 0
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
# end
# # compute phi, there is a difference from alif formula
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
# # compute epsilonRec
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
# # compute epsilonRecA
# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j]
# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j]
# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j]
# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j]
# end
# end
# end
# function onForward(kfn_zit::Array{T},
# zit::Array{T},
# wOut::Array{T},
# vt0::Array{T},
# vt1::Array{T},
# vth::Array{T},
# vRest::Array{T},
# zt1::Array{T},
# alpha::Array{T},
# phi::Array{T},
# epsilonRec::Array{T},
# refractoryCounter::Array{T},
# refractoryDuration::Array{T},
# gammaPd::Array{T},
# firingCounter::Array{T},
# arrayProjection4d::Array{T},
# recSignal::Array{T},
# decayed_vt0::Array{T},
# decayed_epsilonRec::Array{T},
# vt1_diff_vth::Array{T},
# vt1_diff_vth_div_vth::Array{T},
# gammaPd_div_vth::Array{T},
# phiActivation::Array{T},
# ) where T<:Number
# # project 3D kfn zit into 4D lif zit
# zit .= reshape(kfn_zit,
# (size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection4d
# for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
# @. @views refractoryCounter[:,:,i,j] -= 1
# @. @views zt1[:,:,i,j] = 0
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
# @. @views phi[:,:,i,j] = 0
# # compute epsilonRec
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
# else # refractory period is inactive
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wOut[:,:,i,j]
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j]))
# @. @views zt1[:,:,i,j] = 1
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
# @. @views firingCounter[:,:,i,j] += 1
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
# else
# @. @views zt1[:,:,i,j] = 0
# end
# # compute phi, there is a difference from alif formula
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
# # compute epsilonRec
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
# end
# end
# end

View File

@@ -283,16 +283,15 @@ end
function learn!(kfn::kfn_1, progress, device=cpu)
if sum(kfn.timeStep) == 800
println("zitCumulative ", sum(kfn.zitCumulative[:,:,784:size(kfn.zitCumulative, 3)], dims=3))
# println("zitCumulative ", sum(kfn.zitCumulative[:,:,784:size(kfn.zitCumulative, 3)], dims=3))
println("synapse lif $(sum((!isequal).(kfn.lif_wRec, 0))) alif $(sum((!isequal).(kfn.alif_wRec, 0)))")
println("on_synapticActivityCounter 0 ", kfn.on_synapticActivityCounter[:,:,1])
println("on_synapticActivityCounter 5 ", kfn.on_synapticActivityCounter[:,:,6])
println("wOut 0 ", sum(kfn.on_wOut[:,:,1,1], dims=3))
println("wOut 5 ", sum(kfn.on_wOut[:,:,6,1], dims=3))
println("wOut 0 $(sum(kfn.on_wOut[:,:,1,1], dims=3)) total $(sum(sum(kfn.on_wOut[:,:,1,1], dims=3)))")
println("wOut 5 $(sum(kfn.on_wOut[:,:,6,1], dims=3)) total $(sum(sum(kfn.on_wOut[:,:,6,1], dims=3)))")
end
#WORKING compare output neuron 0 synapse activity when input are label 0 and 5, (!isequal).(wOut)
# lif learn
kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapticActivityCounter, kfn.lif_synapseReconnectDelay =
lifLearn(kfn.lif_wRec,
@@ -451,11 +450,12 @@ function onLearn!(wOut,
if progress != 0
# adaptive wOut to help convergence using c_decay
wOut .-= 0.1 .* eta .* wOut # wOut .-= 0.001 .* wOut
# merge learning weight with average learning weight
wOut .+= (sum(wOutChange, dims=4) ./ (size(wOut, 4))) .* arrayProjection4d
# adaptive wOut to help convergence using c_decay
wOut .-= 0.1 .* eta .* wOut # wOut .-= 0.001 .* wOut
else
#TESTING skip
wOutChange .= 0
@@ -478,23 +478,32 @@ function neuroplasticity(synapseConnectionNumber,
# skip neuroplasticity
#TODO I may need to do something with neuronInactivityCounter and other variables
wRecChange .= 0
# # -w all non-fire connection except mature connection
# weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
# # prune weak synapse
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# error("DEBUG -> neuroplasticity")
elseif progress == 1 # some progress whether up or down
# ready to reconnect synapse must not have wRecChange
mask = (!isequal).(wRec, 0)
wRecChange .*= mask
# weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta)
# merge learning weight, all resulting negative wRec will get pruned
mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter, synapseReconnectDelay)
# adjust wRec based on repeatition (90% +w, 10% -w)
growRepeatedPath!(wRec, synapticActivityCounter, eta)
# # adjust wRec based on repeatition (90% +w, 10% -w)
# growRepeatedPath!(wRec, synapticActivityCounter, eta)
# -w all non-fire connection except mature connection
weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
# # -w all non-fire connection except mature connection
# weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
# prune weak synapse
pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# # prune weak synapse
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# rewire synapse connection
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
@@ -503,8 +512,10 @@ function neuroplasticity(synapseConnectionNumber,
elseif progress == 0 # no progress, no weight update, only rewire
wRecChange .= 0
# prune weak synapse
pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta)
# # prune weak synapse
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# rewire synapse connection
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
@@ -515,18 +526,20 @@ function neuroplasticity(synapseConnectionNumber,
# ready to reconnect synapse must not have wRecChange
mask = (!isequal).(wRec, 0)
wRecChange .*= mask
# weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta)
# merge learning weight, all resulting negative wRec will get pruned
mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter, synapseReconnectDelay)
# adjust wRec based on repeatition (90% +w, 10% -w)
growRepeatedPath!(wRec, synapticActivityCounter, eta)
# # adjust wRec based on repeatition (90% +w, 10% -w)
# growRepeatedPath!(wRec, synapticActivityCounter, eta)
# -w all non-fire connection except mature connection
weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
# # -w all non-fire connection except mature connection
# weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)
# prune weak synapse
pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# # prune weak synapse
# pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# rewire synapse connection
rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,

View File

@@ -1,12 +1,13 @@
module snnUtil
export refractoryStatus!, addNewSynapticConn!, mergeLearnWeight!, growRepeatedPath!,
weakenNotMatureSynapse!, pruneSynapse!, rewireSynapse!
weakenNotMatureSynapse!, pruneSynapse!, rewireSynapse!, weakenAllActiveSynapse!
using Random, GeneralUtils
using ..type
#------------------------------------------------------------------------------------------------100
synapseMaxWaittime = 100
function refractoryStatus!(refractoryCounter, refractoryActive, refractoryInactive)
d1, d2, d3, d4 = size(refractoryCounter)
@@ -82,7 +83,7 @@ function mergeLearnWeight!(wRec::AbstractArray, exInType, wRecChange::AbstractAr
# println("wRec 5 $(size(wRec)) ", wRec[:,:,1,1])
GeneralUtils.replaceElements!(flipsign, 1, synapticActivityCounter, 0)
# set pruned synapse to random wait time
waittime = rand((1:1000), size(wRec)) .* flipsign # synapse's random wait time to reconnect
waittime = rand((1:synapseMaxWaittime), size(wRec)) .* flipsign # synapse's random wait time to reconnect
# synapseReconnectDelay counting mode when value is negative hence .* -1
synapseReconnectDelay .= (synapseReconnectDelay .* nonflipsign) .+ (waittime .* -1)
# println("synapseReconnectDelay ", synapseReconnectDelay[:,:,1,1])
@@ -113,6 +114,12 @@ function growRepeatedPath!(wRec, synapticActivityCounter, eta)
# error("DEBUG -> growRepeatedPath!")
end
function weakenAllActiveSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed)
mask_activeSynapse = (!isequal).(synapticActivityCounter, 0)
mask_1 = mask_activeSynapse .* (1 .- (0.1 .* eta))
GeneralUtils.replaceElements!(mask_1, 0, 1) # replace 0 with 1 so mask * Wrec will not get 0 weight
wRec .*= mask_1
end
function weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta) # TODO not fully tested, there is no connection YET where there is 0 synapse activity but wRec is not 0 (subscribed)
mask_inactiveSynapse = isequal.(synapticActivityCounter, 0)
@@ -130,7 +137,7 @@ function pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)
# all weak synapse activity are reset
GeneralUtils.replaceElements!(mask_weak, 1, synapticActivityCounter, 0)
# set pruned synapse to random wait time
waittime = rand((1:1000), size(wRec)) .* mask_weak # synapse's random wait time to reconnect
waittime = rand((1:synapseMaxWaittime), size(wRec)) .* mask_weak # synapse's random wait time to reconnect
# synapseReconnectDelay counting mode when value is negative hence .* -1
synapseReconnectDelay .= (synapseReconnectDelay .* mask_notweak) .+ (waittime .* -1)
# error("DEBUG -> pruneSynapse!")
@@ -159,21 +166,21 @@ function rewireSynapse!(wRec::AbstractArray, neuronInactivityCounter::AbstractAr
if timemark > 0 #TODO not fully tested. mark timeStep available
timemark = Int(timemark)
# get neuron pool at 10 timeStep earlier
earlier = size(zitCumulative, 3) - 10 > 0 ? size(zitCumulative, 3) - 10 : size(zitCumulative, 3)
# get neuron pool within 100 timeStep earlier
earlier = size(zitCumulative, 3) - 100 > 0 ? size(zitCumulative, 3) - 100 : size(zitCumulative, 3)
current = size(zitCumulative, 3)
pool = sum(zitCumulative[:,:,earlier:current], dims=3)
if sum(pool) != 0
indices = findall(x -> x != 0, pool)
pick = rand(indices) # cartesian indice
wRec[pick] = rand(0.01:0.01:0.05)
wRec[pick] = rand(0.001:0.001:0.02)
synapticActivityCounter[pick] = 0
synapseReconnectDelay[pick] = -0.1
# error("DEBUG -> rewireSynapse!")
else # if neurons not firing at all, try again next time
synapticActivityCounter[:,:,n,i4][ind] = 0
synapseReconnectDelay[:,:,n,i4][ind] = rand(1:1000) * -1
synapseReconnectDelay[:,:,n,i4][ind] = rand(1:synapseMaxWaittime) * -1 # wait time
# error("DEBUG -> rewireSynapse!")
end
end

View File

@@ -7,7 +7,7 @@ export
# function
random_wRec
using Random, GeneralUtils
using Random, GeneralUtils, LinearAlgebra
#------------------------------------------------------------------------------------------------100
rng = MersenneTwister(1234)
@@ -372,7 +372,7 @@ function random_wRec(row, col, n, synapseConnectionNumber)
for slice in eachslice(w, dims=3)
pool = shuffle!([1:row*col...])[1:synapseConnectionNumber]
for i in pool
slice[i] = rand(0.01:0.01:0.05) # assign weight to synaptic connection. /10 to start small,
slice[i] = rand() # assign weight to synaptic connection. /10 to start small,
# otherwise RSNN's vt Usually stay negative (-)
end
end
@@ -382,7 +382,7 @@ function random_wRec(row, col, n, synapseConnectionNumber)
# avgWeight = sum(w)/length(w)
# w = w .* (0.01 / avgWeight) # adjust overall weight
return w #(row, col, n)
return normalize!(w) #(row, col, n)
end