ComputeParamsChange()
This commit is contained in:
568
src/forward.jl
568
src/forward.jl
@@ -22,7 +22,6 @@ function (kfn::kfn_1)(input::AbstractArray)
|
|||||||
end
|
end
|
||||||
|
|
||||||
# println(">>> input ", size(input))
|
# println(">>> input ", size(input))
|
||||||
# println(">>> zit ", size(kfn.zit))
|
|
||||||
# println(">>> lif_zit ", size(kfn.lif_zit))
|
# println(">>> lif_zit ", size(kfn.lif_zit))
|
||||||
# println(">>> lif_recSignal ", size(kfn.lif_recSignal))
|
# println(">>> lif_recSignal ", size(kfn.lif_recSignal))
|
||||||
# println(">>> lif_wRec ", size(kfn.lif_wRec))
|
# println(">>> lif_wRec ", size(kfn.lif_wRec))
|
||||||
@@ -31,17 +30,29 @@ function (kfn::kfn_1)(input::AbstractArray)
|
|||||||
# println(">>> lif_vt0 ", size(kfn.lif_vt0))
|
# println(">>> lif_vt0 ", size(kfn.lif_vt0))
|
||||||
# println(">>> lif_vt0 sum ", sum(kfn.lif_vt0))
|
# println(">>> lif_vt0 sum ", sum(kfn.lif_vt0))
|
||||||
|
|
||||||
# pass input_data into input neuron.
|
# update activation matrix with "lif_zt1" and "alif_zt1" by concatenating
|
||||||
GeneralUtils.cartesianAssign!(kfn.zit, input)
|
# (input, lif_zt1, alif_zt1) to form activation matrix
|
||||||
|
_zit = cat(reshape(input, (size(input, 1), size(input, 2), 1, size(input, 3))),
|
||||||
|
reshape(kfn.lif_zt, (size(input, 1), :, 1, size(input, 3))),
|
||||||
|
reshape(kfn.alif_zt, (size(input, 1), :, 1, size(input, 3))), dims=2)
|
||||||
|
kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3)))
|
||||||
|
|
||||||
lifForward( kfn.zit,
|
# pass input_data into input neuron.
|
||||||
kfn.lif_zit,
|
# GeneralUtils.cartesianAssign!(kfn.zit, input)
|
||||||
|
|
||||||
|
# kfn.zit = kfn.zit |> device
|
||||||
|
# input = input |> device
|
||||||
|
|
||||||
|
# project 3D kfn zit into 4D lif zit
|
||||||
|
i1, i2, i3, i4 = size(kfn.lif_zit)
|
||||||
|
kfn.lif_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.lif_arrayProjection4d
|
||||||
|
|
||||||
|
lifForward( kfn.lif_zit,
|
||||||
kfn.lif_wRec,
|
kfn.lif_wRec,
|
||||||
kfn.lif_vt0,
|
kfn.lif_vt,
|
||||||
kfn.lif_vt1,
|
|
||||||
kfn.lif_vth,
|
kfn.lif_vth,
|
||||||
kfn.lif_vRest,
|
kfn.lif_vRest,
|
||||||
kfn.lif_zt1,
|
kfn.lif_zt4d,
|
||||||
kfn.lif_alpha,
|
kfn.lif_alpha,
|
||||||
kfn.lif_phi,
|
kfn.lif_phi,
|
||||||
kfn.lif_epsilonRec,
|
kfn.lif_epsilonRec,
|
||||||
@@ -49,23 +60,18 @@ function (kfn::kfn_1)(input::AbstractArray)
|
|||||||
kfn.lif_refractoryDuration,
|
kfn.lif_refractoryDuration,
|
||||||
kfn.lif_gammaPd,
|
kfn.lif_gammaPd,
|
||||||
kfn.lif_firingCounter,
|
kfn.lif_firingCounter,
|
||||||
kfn.lif_arrayProjection3DTo4D,
|
kfn.lif_recSignal,)
|
||||||
kfn.lif_recSignal,
|
|
||||||
kfn.lif_decayed_vt0,
|
|
||||||
kfn.lif_decayed_epsilonRec,
|
|
||||||
kfn.lif_vt1_diff_vth,
|
|
||||||
kfn.lif_vt1_diff_vth_div_vth,
|
|
||||||
kfn.lif_gammaPd_div_vth,
|
|
||||||
kfn.lif_phiActivation)
|
|
||||||
|
|
||||||
alifForward( kfn.zit,
|
# project 3D kfn zit into 4D alif zit
|
||||||
kfn.alif_zit,
|
i1, i2, i3, i4 = size(kfn.alif_zit)
|
||||||
|
kfn.alif_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.alif_arrayProjection4d
|
||||||
|
|
||||||
|
alifForward(kfn.alif_zit,
|
||||||
kfn.alif_wRec,
|
kfn.alif_wRec,
|
||||||
kfn.alif_vt0,
|
kfn.alif_vt,
|
||||||
kfn.alif_vt1,
|
|
||||||
kfn.alif_vth,
|
kfn.alif_vth,
|
||||||
kfn.alif_vRest,
|
kfn.alif_vRest,
|
||||||
kfn.alif_zt1,
|
kfn.alif_zt4d,
|
||||||
kfn.alif_alpha,
|
kfn.alif_alpha,
|
||||||
kfn.alif_phi,
|
kfn.alif_phi,
|
||||||
kfn.alif_epsilonRec,
|
kfn.alif_epsilonRec,
|
||||||
@@ -73,44 +79,35 @@ function (kfn::kfn_1)(input::AbstractArray)
|
|||||||
kfn.alif_refractoryDuration,
|
kfn.alif_refractoryDuration,
|
||||||
kfn.alif_gammaPd,
|
kfn.alif_gammaPd,
|
||||||
kfn.alif_firingCounter,
|
kfn.alif_firingCounter,
|
||||||
kfn.alif_arrayProjection3DTo4D,
|
|
||||||
kfn.alif_recSignal,
|
kfn.alif_recSignal,
|
||||||
kfn.alif_decayed_vt0,
|
|
||||||
kfn.alif_decayed_epsilonRec,
|
|
||||||
kfn.alif_vt1_diff_vth,
|
|
||||||
kfn.alif_vt1_diff_vth_div_vth,
|
|
||||||
kfn.alif_gammaPd_div_vth,
|
|
||||||
kfn.alif_phiActivation,
|
|
||||||
|
|
||||||
kfn.alif_epsilonRecA,
|
kfn.alif_epsilonRecA,
|
||||||
kfn.alif_avth,
|
|
||||||
kfn.alif_a,
|
kfn.alif_a,
|
||||||
|
kfn.alif_avth,
|
||||||
kfn.alif_beta,
|
kfn.alif_beta,
|
||||||
kfn.alif_rho,
|
kfn.alif_rho,)
|
||||||
kfn.alif_phi_x_epsilonRec,
|
|
||||||
kfn.alif_phi_x_beta,
|
|
||||||
kfn.alif_rho_diff_phi_x_beta,
|
|
||||||
kfn.alif_rho_div_phi_x_beta_x_epsilonRecA,
|
|
||||||
kfn.alif_beta_x_a)
|
|
||||||
# error("DEBUG -> kfn forward")
|
|
||||||
|
|
||||||
|
# reduce lif_zt4d and alif_zt4d into lif_zt, alif_zt (4d -> 1d)
|
||||||
|
kfn.lif_zt .= reduce(max, kfn.lif_zt4d, dims=(1,2))
|
||||||
|
kfn.alif_zt .= reduce(max, kfn.alif_zt4d, dims=(1,2))
|
||||||
|
|
||||||
|
# update activation matrix with "lif_zt1" and "alif_zt1" by concatenating
|
||||||
# update activation matrix by concatenate (input, lif_zt1, alif_zt1) to form activation matrix
|
# (input, lif_zt1, alif_zt1) to form activation matrix
|
||||||
_zit = cat(reshape(input, (size(input, 1), size(input, 2), 1, size(input, 3))),
|
_zit = cat(reshape(input, (size(input, 1), size(input, 2), 1, size(input, 3))),
|
||||||
reshape(kfn.lif_zt1, (size(input, 1), :, 1, size(input, 3))),
|
reshape(kfn.lif_zt, (size(input, 1), :, 1, size(input, 3))),
|
||||||
reshape(kfn.alif_zt1, (size(input, 1), :, 1, size(input, 3))), dims=2)
|
reshape(kfn.alif_zt, (size(input, 1), :, 1, size(input, 3))), dims=2)
|
||||||
kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3)))
|
kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3)))
|
||||||
|
|
||||||
|
# project 3D kfn zit into 4D on zit
|
||||||
|
i1, i2, i3, i4 = size(kfn.on_zit)
|
||||||
|
kfn.on_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.on_arrayProjection4d
|
||||||
|
|
||||||
# read out
|
# read out
|
||||||
onForward( kfn.zit,
|
onForward( kfn.on_zit,
|
||||||
kfn.on_zit,
|
|
||||||
kfn.on_wOut,
|
kfn.on_wOut,
|
||||||
kfn.on_vt0,
|
kfn.on_vt,
|
||||||
kfn.on_vt1,
|
|
||||||
kfn.on_vth,
|
kfn.on_vth,
|
||||||
kfn.on_vRest,
|
kfn.on_vRest,
|
||||||
kfn.on_zt1,
|
kfn.on_zt4d,
|
||||||
kfn.on_alpha,
|
kfn.on_alpha,
|
||||||
kfn.on_phi,
|
kfn.on_phi,
|
||||||
kfn.on_epsilonRec,
|
kfn.on_epsilonRec,
|
||||||
@@ -118,16 +115,11 @@ function (kfn::kfn_1)(input::AbstractArray)
|
|||||||
kfn.on_refractoryDuration,
|
kfn.on_refractoryDuration,
|
||||||
kfn.on_gammaPd,
|
kfn.on_gammaPd,
|
||||||
kfn.on_firingCounter,
|
kfn.on_firingCounter,
|
||||||
kfn.on_arrayProjection3DTo4D,
|
kfn.on_recSignal,)
|
||||||
kfn.on_recSignal,
|
# error("DEBUG -> kfn forward")
|
||||||
kfn.on_decayed_vt0,
|
logit = reshape(kfn.on_zt, (size(input, 1), :))
|
||||||
kfn.on_decayed_epsilonRec,
|
|
||||||
kfn.on_vt1_diff_vth,
|
|
||||||
kfn.on_vt1_diff_vth_div_vth,
|
|
||||||
kfn.on_gammaPd_div_vth,
|
|
||||||
kfn.on_phiActivation)
|
|
||||||
|
|
||||||
return reshape(kfn.on_zt1, (size(input, 1), :)),
|
return logit,
|
||||||
kfn.zit
|
kfn.zit
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -147,7 +139,7 @@ function lifForward(kfn_zit::Array{T},
|
|||||||
refractoryDuration::Array{T},
|
refractoryDuration::Array{T},
|
||||||
gammaPd::Array{T},
|
gammaPd::Array{T},
|
||||||
firingCounter::Array{T},
|
firingCounter::Array{T},
|
||||||
arrayProjection3DTo4D::Array{T},
|
arrayProjection4d::Array{T},
|
||||||
recSignal::Array{T},
|
recSignal::Array{T},
|
||||||
decayed_vt0::Array{T},
|
decayed_vt0::Array{T},
|
||||||
decayed_epsilonRec::Array{T},
|
decayed_epsilonRec::Array{T},
|
||||||
@@ -158,8 +150,8 @@ function lifForward(kfn_zit::Array{T},
|
|||||||
) where T<:Number
|
) where T<:Number
|
||||||
|
|
||||||
# project 3D kfn zit into 4D lif zit
|
# project 3D kfn zit into 4D lif zit
|
||||||
zit .= reshape(kfn_zit,
|
i1, i2, i3, i4 = size(alif_wRec)
|
||||||
(size(wRec, 1), size(wRec, 2), 1, size(wRec, 4))) .* arrayProjection3DTo4D
|
lif_zit .= reshape(kfn_zit, (i1, i2, 1, i4)) .* lif_arrayProjection4d
|
||||||
|
|
||||||
for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
||||||
if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
||||||
@@ -199,8 +191,128 @@ function lifForward(kfn_zit::Array{T},
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function alifForward(kfn_zit::Array{T},
|
# gpu launcher
|
||||||
zit::Array{T},
|
function lifForward( lif_zit::CuArray,
|
||||||
|
lif_wRec::CuArray,
|
||||||
|
lif_vt::CuArray,
|
||||||
|
lif_vth::CuArray,
|
||||||
|
lif_vRest::CuArray,
|
||||||
|
lif_zt::CuArray,
|
||||||
|
lif_alpha::CuArray,
|
||||||
|
lif_phi::CuArray,
|
||||||
|
lif_epsilonRec::CuArray,
|
||||||
|
lif_refractoryCounter::CuArray,
|
||||||
|
lif_refractoryDuration::CuArray,
|
||||||
|
lif_gammaPd::CuArray,
|
||||||
|
lif_firingCounter::CuArray,
|
||||||
|
lif_recSignal::CuArray,)
|
||||||
|
|
||||||
|
kernel = @cuda launch=false lifForward( lif_zit,
|
||||||
|
lif_wRec,
|
||||||
|
lif_vt,
|
||||||
|
lif_vth,
|
||||||
|
lif_vRest,
|
||||||
|
lif_zt,
|
||||||
|
lif_alpha,
|
||||||
|
lif_phi,
|
||||||
|
lif_epsilonRec,
|
||||||
|
lif_refractoryCounter,
|
||||||
|
lif_refractoryDuration,
|
||||||
|
lif_gammaPd,
|
||||||
|
lif_firingCounter,
|
||||||
|
lif_recSignal,
|
||||||
|
GeneralUtils.linear_to_cartesian)
|
||||||
|
config = launch_configuration(kernel.fun)
|
||||||
|
|
||||||
|
|
||||||
|
# threads to be launched. Since one can't launch exact thread number the kernel needs,
|
||||||
|
# one just launch threads more than this kernel needs then use a guard inside the kernel
|
||||||
|
# to prevent unused threads to access memory.
|
||||||
|
threads = min(1024, config.threads) # depend on gpu. Most NVIDIA gpu has 1024 threads per block
|
||||||
|
|
||||||
|
# total desired threads to launch to gpu. Usually 1 thread per 1 matrix element
|
||||||
|
totalThreads = length(lif_wRec)
|
||||||
|
|
||||||
|
blocks = cld(totalThreads, threads)
|
||||||
|
# println("launching gpu kernel")
|
||||||
|
CUDA.@sync begin
|
||||||
|
kernel( lif_zit,
|
||||||
|
lif_wRec,
|
||||||
|
lif_vt,
|
||||||
|
lif_vth,
|
||||||
|
lif_vRest,
|
||||||
|
lif_zt,
|
||||||
|
lif_alpha,
|
||||||
|
lif_phi,
|
||||||
|
lif_epsilonRec,
|
||||||
|
lif_refractoryCounter,
|
||||||
|
lif_refractoryDuration,
|
||||||
|
lif_gammaPd,
|
||||||
|
lif_firingCounter,
|
||||||
|
lif_recSignal,
|
||||||
|
GeneralUtils.linear_to_cartesian; threads, blocks)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# gpu kernel
|
||||||
|
function lifForward( zit,
|
||||||
|
wRec,
|
||||||
|
vt,
|
||||||
|
vth,
|
||||||
|
vRest,
|
||||||
|
zt,
|
||||||
|
alpha,
|
||||||
|
phi,
|
||||||
|
epsilonRec,
|
||||||
|
refractoryCounter,
|
||||||
|
refractoryDuration,
|
||||||
|
gammaPd,
|
||||||
|
firingCounter,
|
||||||
|
recSignal,
|
||||||
|
linear_to_cartesian)
|
||||||
|
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x # gpu threads index
|
||||||
|
|
||||||
|
if i <= length(wRec)
|
||||||
|
# cartesian index
|
||||||
|
i1, i2, i3, i4 = linear_to_cartesian(i, size(wRec))
|
||||||
|
# @cuprintln("gpu thread $i $i1 $i2 $i3 $i4")
|
||||||
|
|
||||||
|
refractoryCounter[i] -= 1
|
||||||
|
|
||||||
|
if refractoryCounter[i] > 0 # refractory period is active
|
||||||
|
refractoryCounter[i] -= 1
|
||||||
|
zt[i] = 0
|
||||||
|
vt[i] = alpha[i] * vt[i]
|
||||||
|
phi[i] = 0
|
||||||
|
|
||||||
|
# compute epsilonRec
|
||||||
|
epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i]
|
||||||
|
|
||||||
|
else # refractory period is inactive
|
||||||
|
recSignal[i] = zit[i] * wRec[i]
|
||||||
|
vt[i] = (alpha[i] * vt[i]) + sum(@view(recSignal[:,:,i3,i4]))
|
||||||
|
|
||||||
|
# fires if membrane potential exceed threshold
|
||||||
|
if vt[i] > vth[i]
|
||||||
|
zt[i] = 1
|
||||||
|
refractoryCounter[i] = refractoryDuration[i]
|
||||||
|
firingCounter[i] += 1
|
||||||
|
vt[i] = vRest[i]
|
||||||
|
else
|
||||||
|
zt[i] = 0
|
||||||
|
end
|
||||||
|
|
||||||
|
# compute phi, there is a difference from lif formula
|
||||||
|
phi[i] = (gammaPd[i] / vth[i]) * max(0, 1 - ((vt[i] - vth[i]) / vth[i]))
|
||||||
|
|
||||||
|
# compute epsilonRec
|
||||||
|
epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function alifForward(zit::Array{T},
|
||||||
wRec::Array{T},
|
wRec::Array{T},
|
||||||
vt0::Array{T},
|
vt0::Array{T},
|
||||||
vt1::Array{T},
|
vt1::Array{T},
|
||||||
@@ -214,7 +326,6 @@ function alifForward(kfn_zit::Array{T},
|
|||||||
refractoryDuration::Array{T},
|
refractoryDuration::Array{T},
|
||||||
gammaPd::Array{T},
|
gammaPd::Array{T},
|
||||||
firingCounter::Array{T},
|
firingCounter::Array{T},
|
||||||
arrayProjection3DTo4D::Array{T},
|
|
||||||
recSignal::Array{T},
|
recSignal::Array{T},
|
||||||
decayed_vt0::Array{T},
|
decayed_vt0::Array{T},
|
||||||
decayed_epsilonRec::Array{T},
|
decayed_epsilonRec::Array{T},
|
||||||
@@ -235,11 +346,6 @@ function alifForward(kfn_zit::Array{T},
|
|||||||
beta_x_a::Array{T},
|
beta_x_a::Array{T},
|
||||||
) where T<:Number
|
) where T<:Number
|
||||||
|
|
||||||
|
|
||||||
# project 3D kfn zit into 4D lif zit
|
|
||||||
zit .= reshape(kfn_zit,
|
|
||||||
(size(wRec, 1), size(wRec, 2), 1, size(wRec, 4))) .* arrayProjection3DTo4D
|
|
||||||
|
|
||||||
for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
||||||
if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
||||||
@. @views refractoryCounter[:,:,i,j] -= 1
|
@. @views refractoryCounter[:,:,i,j] -= 1
|
||||||
@@ -305,6 +411,164 @@ function alifForward(kfn_zit::Array{T},
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# gpu launcher
|
||||||
|
function alifForward( alif_zit::CuArray,
|
||||||
|
alif_wRec::CuArray,
|
||||||
|
alif_vt::CuArray,
|
||||||
|
alif_vth::CuArray,
|
||||||
|
alif_vRest::CuArray,
|
||||||
|
alif_zt::CuArray,
|
||||||
|
alif_alpha::CuArray,
|
||||||
|
alif_phi::CuArray,
|
||||||
|
alif_epsilonRec::CuArray,
|
||||||
|
alif_refractoryCounter::CuArray,
|
||||||
|
alif_refractoryDuration::CuArray,
|
||||||
|
alif_gammaPd::CuArray,
|
||||||
|
alif_firingCounter::CuArray,
|
||||||
|
alif_recSignal::CuArray,
|
||||||
|
alif_epsilonRecA::CuArray,
|
||||||
|
alif_a::CuArray,
|
||||||
|
alif_avth::CuArray,
|
||||||
|
alif_beta::CuArray,
|
||||||
|
alif_rho::CuArray,
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel = @cuda launch=false alifForward( alif_zit,
|
||||||
|
alif_wRec,
|
||||||
|
alif_vt,
|
||||||
|
alif_vth,
|
||||||
|
alif_vRest,
|
||||||
|
alif_zt,
|
||||||
|
alif_alpha,
|
||||||
|
alif_phi,
|
||||||
|
alif_epsilonRec,
|
||||||
|
alif_refractoryCounter,
|
||||||
|
alif_refractoryDuration,
|
||||||
|
alif_gammaPd,
|
||||||
|
alif_firingCounter,
|
||||||
|
alif_recSignal,
|
||||||
|
alif_epsilonRecA,
|
||||||
|
alif_a,
|
||||||
|
alif_avth,
|
||||||
|
alif_beta,
|
||||||
|
alif_rho,
|
||||||
|
GeneralUtils.linear_to_cartesian)
|
||||||
|
config = launch_configuration(kernel.fun)
|
||||||
|
|
||||||
|
# threads to be launched. Since one can't launch exact thread number the kernel needs,
|
||||||
|
# one just launch threads more than this kernel needs then use a guard inside the kernel
|
||||||
|
# to prevent unused threads to access memory.
|
||||||
|
threads = min(1024, config.threads) # depend on gpu. Most NVIDIA gpu has 1024 threads per block
|
||||||
|
|
||||||
|
# total desired threads to launch to gpu. Usually 1 thread per 1 matrix element
|
||||||
|
totalThreads = length(alif_wRec)
|
||||||
|
|
||||||
|
blocks = cld(totalThreads, threads)
|
||||||
|
# println("launching gpu kernel")
|
||||||
|
CUDA.@sync begin
|
||||||
|
kernel( alif_zit,
|
||||||
|
alif_wRec,
|
||||||
|
alif_vt,
|
||||||
|
alif_vth,
|
||||||
|
alif_vRest,
|
||||||
|
alif_zt,
|
||||||
|
alif_alpha,
|
||||||
|
alif_phi,
|
||||||
|
alif_epsilonRec,
|
||||||
|
alif_refractoryCounter,
|
||||||
|
alif_refractoryDuration,
|
||||||
|
alif_gammaPd,
|
||||||
|
alif_firingCounter,
|
||||||
|
alif_recSignal,
|
||||||
|
alif_epsilonRecA,
|
||||||
|
alif_a,
|
||||||
|
alif_avth,
|
||||||
|
alif_beta,
|
||||||
|
alif_rho,
|
||||||
|
GeneralUtils.linear_to_cartesian; threads, blocks)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# gpu kernel
|
||||||
|
function alifForward( zit,
|
||||||
|
wRec,
|
||||||
|
vt,
|
||||||
|
vth,
|
||||||
|
vRest,
|
||||||
|
zt,
|
||||||
|
alpha,
|
||||||
|
phi,
|
||||||
|
epsilonRec,
|
||||||
|
refractoryCounter,
|
||||||
|
refractoryDuration,
|
||||||
|
gammaPd,
|
||||||
|
firingCounter,
|
||||||
|
recSignal,
|
||||||
|
epsilonRecA,
|
||||||
|
a,
|
||||||
|
avth,
|
||||||
|
beta,
|
||||||
|
rho,
|
||||||
|
linear_to_cartesian)
|
||||||
|
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x # gpu threads index
|
||||||
|
|
||||||
|
if i <= length(wRec)
|
||||||
|
# cartesian index
|
||||||
|
i1, i2, i3, i4 = linear_to_cartesian(i, size(wRec))
|
||||||
|
# @cuprintln("gpu thread $i $i1 $i2 $i3 $i4")
|
||||||
|
|
||||||
|
refractoryCounter[i] -= 1
|
||||||
|
|
||||||
|
if refractoryCounter[i] > 0 # refractory period is active
|
||||||
|
refractoryCounter[i] -= 1
|
||||||
|
zt[i] = 0
|
||||||
|
vt[i] = alpha[i] * vt[i]
|
||||||
|
phi[i] = 0
|
||||||
|
a[i] = rho[i] * a[i]
|
||||||
|
|
||||||
|
# compute epsilonRec
|
||||||
|
epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i]
|
||||||
|
|
||||||
|
# compute epsilonRecA
|
||||||
|
epsilonRecA[i] = (phi[i] * epsilonRec[i]) +
|
||||||
|
((rho[i] - (phi[i] * beta[i])) * epsilonRecA[i])
|
||||||
|
|
||||||
|
# compute avth
|
||||||
|
avth[i] = vth[i] + (beta[i] * a[i])
|
||||||
|
|
||||||
|
else # refractory period is inactive
|
||||||
|
recSignal[i] = zit[i] * wRec[i]
|
||||||
|
vt[i] = (alpha[i] * vt[i]) + sum(@view(recSignal[:,:,i3,i4]))
|
||||||
|
|
||||||
|
# compute avth
|
||||||
|
avth[i] = vth[i] + (beta[i] * a[i])
|
||||||
|
|
||||||
|
# fires if membrane potential exceed threshold
|
||||||
|
if vt[i] > avth[i]
|
||||||
|
zt[i] = 1
|
||||||
|
refractoryCounter[i] = refractoryDuration[i]
|
||||||
|
firingCounter[i] += 1
|
||||||
|
vt[i] = vRest[i]
|
||||||
|
a[i] = (rho[i] * a[i]) + 1
|
||||||
|
else
|
||||||
|
zt[i] = 0
|
||||||
|
a[i] = (rho[i] * a[i])
|
||||||
|
end
|
||||||
|
|
||||||
|
# compute phi, there is a difference from alif formula
|
||||||
|
phi[i] = (gammaPd[i] / vth[i]) * max(0, 1 - ((vt[i] - vth[i]) / vth[i]))
|
||||||
|
|
||||||
|
# compute epsilonRec
|
||||||
|
epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i]
|
||||||
|
|
||||||
|
# compute epsilonRecA
|
||||||
|
epsilonRecA[i] = (phi[i] * epsilonRec[i]) +
|
||||||
|
((rho[i] - (phi[i] * beta[i])) * epsilonRecA[i])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
function onForward(kfn_zit::Array{T},
|
function onForward(kfn_zit::Array{T},
|
||||||
zit::Array{T},
|
zit::Array{T},
|
||||||
wOut::Array{T},
|
wOut::Array{T},
|
||||||
@@ -320,7 +584,7 @@ function onForward(kfn_zit::Array{T},
|
|||||||
refractoryDuration::Array{T},
|
refractoryDuration::Array{T},
|
||||||
gammaPd::Array{T},
|
gammaPd::Array{T},
|
||||||
firingCounter::Array{T},
|
firingCounter::Array{T},
|
||||||
arrayProjection3DTo4D::Array{T},
|
arrayProjection4d::Array{T},
|
||||||
recSignal::Array{T},
|
recSignal::Array{T},
|
||||||
decayed_vt0::Array{T},
|
decayed_vt0::Array{T},
|
||||||
decayed_epsilonRec::Array{T},
|
decayed_epsilonRec::Array{T},
|
||||||
@@ -332,7 +596,7 @@ function onForward(kfn_zit::Array{T},
|
|||||||
|
|
||||||
# project 3D kfn zit into 4D lif zit
|
# project 3D kfn zit into 4D lif zit
|
||||||
zit .= reshape(kfn_zit,
|
zit .= reshape(kfn_zit,
|
||||||
(size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection3DTo4D
|
(size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection4d
|
||||||
|
|
||||||
for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch
|
for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch
|
||||||
if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
||||||
@@ -372,57 +636,125 @@ function onForward(kfn_zit::Array{T},
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# function onForward(kfn_zit,
|
# gpu launcher
|
||||||
# zit,
|
function onForward( on_zit::CuArray,
|
||||||
# wOut,
|
on_wOut::CuArray,
|
||||||
# vt0,
|
on_vt::CuArray,
|
||||||
# vt1,
|
on_vth::CuArray,
|
||||||
# vth,
|
on_vRest::CuArray,
|
||||||
# vRest,
|
on_zt::CuArray,
|
||||||
# zt1,
|
on_alpha::CuArray,
|
||||||
# alpha,
|
on_phi::CuArray,
|
||||||
# phi,
|
on_epsilonRec::CuArray,
|
||||||
# epsilonRec,
|
on_refractoryCounter::CuArray,
|
||||||
# refractoryCounter,
|
on_refractoryDuration::CuArray,
|
||||||
# refractoryDuration,
|
on_gammaPd::CuArray,
|
||||||
# gammaPd,
|
on_firingCounter::CuArray,
|
||||||
# firingCounter)
|
on_recSignal::CuArray)
|
||||||
# d1, d2, d3, d4 = size(wOut)
|
|
||||||
# zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wOut)...) # project zit into zit
|
|
||||||
|
|
||||||
# for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
|
kernel = @cuda launch=false onForward( on_zit,
|
||||||
# if view(refractoryCounter, :, :, i, j)[1] > 0 # neuron is inactive (in refractory period)
|
on_wOut,
|
||||||
# view(refractoryCounter, :, :, i, j)[1] -= 1
|
on_vt,
|
||||||
# view(zt1, :, :, i, j)[1] = 0
|
on_vth,
|
||||||
# view(vt1, :, :, i, j)[1] =
|
on_vRest,
|
||||||
# view(alpha, :, :, i, j)[1] * view(vt0, :, :, i, j)[1]
|
on_zt,
|
||||||
# view(phi, :, :, i, j)[1] = 0.0
|
on_alpha,
|
||||||
# view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .*
|
on_phi,
|
||||||
# view(epsilonRec, :, :, i, j)
|
on_epsilonRec,
|
||||||
# else # neuron is active
|
on_refractoryCounter,
|
||||||
# view(vt1, :, :, i, j)[1] =
|
on_refractoryDuration,
|
||||||
# (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) +
|
on_gammaPd,
|
||||||
# sum(view(zit, :, :, i, j) .* view(wOut, :, :, i, j))
|
on_firingCounter,
|
||||||
# if view(vt1, :, :, i, j)[1] > view(vth, :, :, i, j)[1]
|
on_recSignal,
|
||||||
# view(zt1, :, :, i, j)[1] = 1
|
GeneralUtils.linear_to_cartesian)
|
||||||
# view(refractoryCounter, :, :, i, j)[1] =
|
config = launch_configuration(kernel.fun)
|
||||||
# view(refractoryDuration, :, :, i, j)[1]
|
|
||||||
# view(firingCounter, :, :, i, j)[1] += 1
|
# threads to be launched. Since one can't launch exact thread number the kernel needs,
|
||||||
# view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1]
|
# one just launch threads more than this kernel needs then use a guard inside the kernel
|
||||||
# else
|
# to prevent unused threads to access memory.
|
||||||
# view(zt1, :, :, i, j)[1] = 0
|
threads = min(1024, config.threads) # depend on gpu. Most NVIDIA gpu has 1024 threads per block
|
||||||
# end
|
|
||||||
# # there is a difference from alif formula
|
# total desired threads to launch to gpu. Usually 1 thread per 1 matrix element
|
||||||
# view(phi, :, :, i, j)[1] =
|
totalThreads = length(on_wOut)
|
||||||
# (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) *
|
|
||||||
# max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(vth, :, :, i, j)[1]) /
|
blocks = cld(totalThreads, threads)
|
||||||
# view(vth, :, :, i, j)[1]))
|
# println("launching gpu kernel")
|
||||||
# view(epsilonRec, :, :, i, j) .=
|
CUDA.@sync begin
|
||||||
# (view(alpha, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) +
|
kernel( on_zit,
|
||||||
# view(zit, :, :, i, j)
|
on_wOut,
|
||||||
# end
|
on_vt,
|
||||||
# end
|
on_vth,
|
||||||
# end
|
on_vRest,
|
||||||
|
on_zt,
|
||||||
|
on_alpha,
|
||||||
|
on_phi,
|
||||||
|
on_epsilonRec,
|
||||||
|
on_refractoryCounter,
|
||||||
|
on_refractoryDuration,
|
||||||
|
on_gammaPd,
|
||||||
|
on_firingCounter,
|
||||||
|
on_recSignal,
|
||||||
|
GeneralUtils.linear_to_cartesian; threads, blocks)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# gpu kernel
|
||||||
|
function onForward( zit,
|
||||||
|
wOut,
|
||||||
|
vt,
|
||||||
|
vth,
|
||||||
|
vRest,
|
||||||
|
zt,
|
||||||
|
alpha,
|
||||||
|
phi,
|
||||||
|
epsilonRec,
|
||||||
|
refractoryCounter,
|
||||||
|
refractoryDuration,
|
||||||
|
gammaPd,
|
||||||
|
firingCounter,
|
||||||
|
recSignal,
|
||||||
|
linear_to_cartesian)
|
||||||
|
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x # gpu threads index
|
||||||
|
|
||||||
|
if i <= length(wOut)
|
||||||
|
# cartesian index
|
||||||
|
i1, i2, i3, i4 = linear_to_cartesian(i, size(wOut))
|
||||||
|
# @cuprintln("gpu thread $i $i1 $i2 $i3 $i4")
|
||||||
|
|
||||||
|
refractoryCounter[i] -= 1
|
||||||
|
|
||||||
|
if refractoryCounter[i] > 0 # refractory period is active
|
||||||
|
refractoryCounter[i] -= 1
|
||||||
|
zt[i] = 0
|
||||||
|
vt[i] = alpha[i] * vt[i]
|
||||||
|
phi[i] = 0
|
||||||
|
|
||||||
|
# compute epsilonRec
|
||||||
|
epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i]
|
||||||
|
|
||||||
|
else # refractory period is inactive
|
||||||
|
recSignal[i] = zit[i] * wOut[i]
|
||||||
|
vt[i] = (alpha[i] * vt[i]) + sum(@view(recSignal[:,:,i3,i4]))
|
||||||
|
|
||||||
|
# fires if membrane potential exceed threshold
|
||||||
|
if vt[i] > vth[i]
|
||||||
|
zt[i] = 1
|
||||||
|
refractoryCounter[i] = refractoryDuration[i]
|
||||||
|
firingCounter[i] += 1
|
||||||
|
vt[i] = vRest[i]
|
||||||
|
else
|
||||||
|
zt[i] = 0
|
||||||
|
end
|
||||||
|
|
||||||
|
# compute phi, there is a difference from on formula
|
||||||
|
phi[i] = (gammaPd[i] / vth[i]) * max(0, 1 - ((vt[i] - vth[i]) / vth[i]))
|
||||||
|
|
||||||
|
# compute epsilonRec
|
||||||
|
epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
133
src/learn.jl
133
src/learn.jl
@@ -2,7 +2,7 @@ module learn
|
|||||||
|
|
||||||
export learn!, compute_paramsChange!
|
export learn!, compute_paramsChange!
|
||||||
|
|
||||||
using Statistics, Random, LinearAlgebra, JSON3, Flux, Dates
|
using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..type, ..snnUtil
|
using ..type, ..snnUtil
|
||||||
|
|
||||||
@@ -11,43 +11,76 @@ using ..type, ..snnUtil
|
|||||||
function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
|
function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
|
||||||
|
|
||||||
|
|
||||||
|
# kfn.lif_error .= modelError .* kfn.lif_arrayProjection4d
|
||||||
|
# kfn.alif_error .= modelError .* kfn.alif_arrayProjection4d
|
||||||
|
# kfn.on_error .= reshape(outputError, (1, 1, :, size(kfn.on_arrayProjection4d, 4))) .* kfn.on_arrayProjection4d
|
||||||
|
|
||||||
|
|
||||||
lifComputeParamsChange!(kfn.lif_phi,
|
lifComputeParamsChange!(kfn.lif_phi,
|
||||||
kfn.lif_epsilonRec,
|
kfn.lif_epsilonRec,
|
||||||
kfn.lif_eta,
|
kfn.lif_eta,
|
||||||
|
kfn.lif_eRec,
|
||||||
kfn.lif_wRec,
|
kfn.lif_wRec,
|
||||||
kfn.lif_wRecChange,
|
kfn.lif_wRecChange,
|
||||||
kfn.on_wOut,
|
kfn.on_wOut,
|
||||||
|
kfn.lif_arrayProjection4d,
|
||||||
|
kfn.lif_error,
|
||||||
modelError)
|
modelError)
|
||||||
|
|
||||||
alifComputeParamsChange!(kfn.alif_phi,
|
alifComputeParamsChange!(kfn.alif_phi,
|
||||||
kfn.alif_epsilonRec,
|
kfn.alif_epsilonRec,
|
||||||
kfn.alif_epsilonRecA,
|
|
||||||
kfn.alif_eta,
|
kfn.alif_eta,
|
||||||
|
kfn.alif_eRec,
|
||||||
kfn.alif_wRec,
|
kfn.alif_wRec,
|
||||||
kfn.alif_wRecChange,
|
kfn.alif_wRecChange,
|
||||||
kfn.alif_beta,
|
|
||||||
kfn.on_wOut,
|
kfn.on_wOut,
|
||||||
modelError)
|
kfn.alif_arrayProjection4d,
|
||||||
|
kfn.alif_error,
|
||||||
|
modelError,
|
||||||
|
kfn.alif_beta)
|
||||||
|
|
||||||
onComputeParamsChange!(kfn.on_phi,
|
onComputeParamsChange!(kfn.on_phi,
|
||||||
kfn.on_epsilonRec,
|
kfn.on_epsilonRec,
|
||||||
kfn.on_eta,
|
kfn.on_eta,
|
||||||
|
kfn.on_eRec,
|
||||||
|
kfn.on_wOut,
|
||||||
kfn.on_wOutChange,
|
kfn.on_wOutChange,
|
||||||
outputError)
|
outputError)
|
||||||
|
error("DEBUG -> kfn compute_paramsChange! $(Dates.now())")
|
||||||
|
|
||||||
error("debug end -> kfn compute_paramsChange! $(Dates.now())")
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function lifComputeParamsChange!( phi,
|
function lifComputeParamsChange!( phi::CuArray,
|
||||||
epsilonRec,
|
epsilonRec::CuArray,
|
||||||
eta,
|
eta::CuArray,
|
||||||
wRec,
|
eRec::CuArray,
|
||||||
wRecChange,
|
wRec::CuArray,
|
||||||
wOut,
|
wRecChange::CuArray,
|
||||||
modelError)
|
wOut::CuArray,
|
||||||
d1, d2, d3, d4 = size(epsilonRec)
|
arrayProjection4d::CuArray,
|
||||||
|
nError::CuArray,
|
||||||
|
modelError::CuArray)
|
||||||
|
|
||||||
|
wOutSum = sum(wOut, dims=3) .* arrayProjection4d
|
||||||
|
|
||||||
|
# nError a.k.a. learning signal use dopamine concept,
|
||||||
|
# this neuron receive summed error signal (modelError)
|
||||||
|
nError .= (modelError .* arrayProjection4d) .* wOutSum
|
||||||
|
eRec .= phi .* epsilonRec
|
||||||
|
|
||||||
|
# GeneralUtils.isNotEqual(wRec, 0) is a subscribe filter use to filter out non-subscribed wRecChange
|
||||||
|
wRecChange .+= ((-1 .* eta) .* nError .* eRec) .* GeneralUtils.isNotEqual.(wRec, 0)
|
||||||
|
# error("DEBUG -> lifComputeParamsChange! $(Dates.now())")
|
||||||
|
end
|
||||||
|
|
||||||
|
function lifComputeParamsChange!( phi::AbstractArray,
|
||||||
|
epsilonRec::AbstractArray,
|
||||||
|
eta::AbstractArray,
|
||||||
|
wRec::AbstractArray,
|
||||||
|
wRecChange::AbstractArray,
|
||||||
|
wOut::AbstractArray,
|
||||||
|
modelError::AbstractArray)
|
||||||
|
d1, d2, d3, d4 = size(epsilonRec)
|
||||||
|
error("DEBUG -> lifComputeParamsChange! $(Dates.now())")
|
||||||
# Bₖⱼ in paper, sum() to get each neuron's total wOut weight
|
# Bₖⱼ in paper, sum() to get each neuron's total wOut weight
|
||||||
wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))
|
wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))
|
||||||
|
|
||||||
@@ -68,7 +101,6 @@ function lifComputeParamsChange!( phi,
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function alifComputeParamsChange!( phi,
|
function alifComputeParamsChange!( phi,
|
||||||
epsilonRec,
|
epsilonRec,
|
||||||
epsilonRecA,
|
epsilonRecA,
|
||||||
@@ -106,12 +138,35 @@ function alifComputeParamsChange!( phi,
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function alifComputeParamsChange!( phi::CuArray,
|
||||||
|
epsilonRec::CuArray,
|
||||||
|
eta::CuArray,
|
||||||
|
eRec::CuArray,
|
||||||
|
wRec::CuArray,
|
||||||
|
wRecChange::CuArray,
|
||||||
|
wOut::CuArray,
|
||||||
|
arrayProjection4d::CuArray,
|
||||||
|
nError::CuArray,
|
||||||
|
modelError::CuArray,
|
||||||
|
beta::CuArray)
|
||||||
|
|
||||||
function onComputeParamsChange!(phi,
|
wOutSum = sum(wOut, dims=3) .* arrayProjection4d
|
||||||
epsilonRec,
|
|
||||||
eta,
|
# nError a.k.a. learning signal use dopamine concept,
|
||||||
wOutChange,
|
# this neuron receive summed error signal (modelError)
|
||||||
outputError)
|
nError .= (modelError .* arrayProjection4d) .* wOutSum
|
||||||
|
eRec .= (phi .* epsilonRec) .+ (phi .* epsilonRec .* beta)
|
||||||
|
|
||||||
|
# GeneralUtils.isNotEqual(wRec, 0) is a subscribe filter use to filter out non-subscribed wRecChange
|
||||||
|
wRecChange .+= ((-1 .* eta) .* nError .* eRec) .* GeneralUtils.isNotEqual.(wRec, 0)
|
||||||
|
# error("DEBUG -> alifComputeParamsChange! $(Dates.now())")
|
||||||
|
end
|
||||||
|
|
||||||
|
function onComputeParamsChange!(phi::AbstractArray,
|
||||||
|
epsilonRec::AbstractArray,
|
||||||
|
eta::AbstractArray,
|
||||||
|
wOutChange::AbstractArray,
|
||||||
|
outputError::AbstractArray)
|
||||||
d1, d2, d3, d4 = size(epsilonRec)
|
d1, d2, d3, d4 = size(epsilonRec)
|
||||||
|
|
||||||
for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
|
for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
|
||||||
@@ -127,30 +182,24 @@ function onComputeParamsChange!(phi,
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# function onComputeParamsChange!(wOut,
|
function onComputeParamsChange!(phi::CuArray,
|
||||||
# epsilonRec,
|
epsilonRec::CuArray,
|
||||||
# eta,
|
eta::CuArray,
|
||||||
# wOutChange,
|
eRec::CuArray,
|
||||||
# bChange,
|
wOut::CuArray,
|
||||||
# outputError)
|
wOutChange::CuArray,
|
||||||
# d1, d2, d3, d4 = size(epsilonRec)
|
outputError::CuArray # outputError is output neuron's error
|
||||||
# println(">>> epsilon ", size(epsilonRec))
|
)
|
||||||
# println(">>> outputError ", size(outputError))
|
|
||||||
|
|
||||||
|
# nError a.k.a. learning signal use dopamine concept,
|
||||||
|
# this neuron receive summed error signal (modelError)
|
||||||
|
eRec .= (phi .* epsilonRec) .* reshape(outputError, (1, 1, :, size(epsilonRec, 4)))
|
||||||
|
|
||||||
# # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
|
# GeneralUtils.isNotEqual(wRec, 0) is a subscribe filter use to filter out non-subscribed wRecChange
|
||||||
|
wOutChange .+= ((-1 .* eta) .* eRec) .* GeneralUtils.isNotEqual.(wOut, 0)
|
||||||
# for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
|
|
||||||
# # how much error of this neuron 1-spike causing each output neuron's error
|
|
||||||
|
|
||||||
# view(wOutChange, :, :, i, j) .+=
|
|
||||||
# (-1 * view(eta, :, :, i, j)[1] * view(outputError, :, j)[i]) .*
|
|
||||||
# view(epsilonRec, :, :, i, j)
|
|
||||||
# end
|
|
||||||
# #TODO add b
|
|
||||||
# error(">>> DEBUG -> onComputeParamsChange!")
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
# error("DEBUG -> onComputeParamsChange! $(Dates.now())")
|
||||||
|
end
|
||||||
|
|
||||||
function learn!(kfn::kfn_1)
|
function learn!(kfn::kfn_1)
|
||||||
#WORKING lif learn
|
#WORKING lif learn
|
||||||
|
|||||||
262
src/type.jl
262
src/type.jl
@@ -22,6 +22,8 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
timeStep::Union{AbstractArray, Nothing} = nothing
|
timeStep::Union{AbstractArray, Nothing} = nothing
|
||||||
learningStage::Union{AbstractArray, Nothing} = nothing # 0 inference, 1 start, 2 during, 3 end learning
|
learningStage::Union{AbstractArray, Nothing} = nothing # 0 inference, 1 start, 2 during, 3 end learning
|
||||||
zit::Union{AbstractArray, Nothing} = nothing # 3D activation matrix
|
zit::Union{AbstractArray, Nothing} = nothing # 3D activation matrix
|
||||||
|
modelError::Union{AbstractArray, Nothing} = nothing # store RSNN error
|
||||||
|
outputError::Union{AbstractArray, Nothing} = nothing # store output neurons error
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# LIF Neurons #
|
# LIF Neurons #
|
||||||
@@ -31,12 +33,11 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
|
|
||||||
# main variables according to papers
|
# main variables according to papers
|
||||||
lif_wRec::Union{AbstractArray, Nothing} = nothing
|
lif_wRec::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_vt0::Union{AbstractArray, Nothing} = nothing
|
lif_vt::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_vt1::Union{AbstractArray, Nothing} = nothing
|
|
||||||
lif_vth::Union{AbstractArray, Nothing} = nothing
|
lif_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_vRest::Union{AbstractArray, Nothing} = nothing
|
lif_vRest::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_zt0::Union{AbstractArray, Nothing} = nothing
|
lif_zt::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_zt1::Union{AbstractArray, Nothing} = nothing
|
lif_zt4d::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_refractoryCounter::Union{AbstractArray, Nothing} = nothing
|
lif_refractoryCounter::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_refractoryDuration::Union{AbstractArray, Nothing} = nothing
|
lif_refractoryDuration::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_alpha::Union{AbstractArray, Nothing} = nothing
|
lif_alpha::Union{AbstractArray, Nothing} = nothing
|
||||||
@@ -48,18 +49,18 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
lif_eta::Union{AbstractArray, Nothing} = nothing
|
lif_eta::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_gammaPd::Union{AbstractArray, Nothing} = nothing
|
lif_gammaPd::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_wRecChange::Union{AbstractArray, Nothing} = nothing
|
lif_wRecChange::Union{AbstractArray, Nothing} = nothing
|
||||||
|
lif_error::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
lif_firingCounter::Union{AbstractArray, Nothing} = nothing
|
lif_firingCounter::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
# pre-allocation array
|
# pre-allocation array
|
||||||
lif_arrayProjection3DTo4D::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d
|
lif_arrayProjection4d::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d
|
||||||
lif_recSignal::Union{AbstractArray, Nothing} = nothing
|
lif_recSignal::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_decayed_vt0::Union{AbstractArray, Nothing} = nothing
|
# lif_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
# lif_vt_diff_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_vt1_diff_vth::Union{AbstractArray, Nothing} = nothing
|
# lif_vt_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_vt1_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing
|
# lif_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing
|
# lif_phiActivation::Union{AbstractArray, Nothing} = nothing
|
||||||
lif_phiActivation::Union{AbstractArray, Nothing} = nothing
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# ALIF Neurons #
|
# ALIF Neurons #
|
||||||
@@ -67,12 +68,11 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
alif_zit::Union{AbstractArray, Nothing} = nothing
|
alif_zit::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
alif_wRec::Union{AbstractArray, Nothing} = nothing
|
alif_wRec::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_vt0::Union{AbstractArray, Nothing} = nothing
|
alif_vt::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_vt1::Union{AbstractArray, Nothing} = nothing
|
|
||||||
alif_vth::Union{AbstractArray, Nothing} = nothing
|
alif_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_vRest::Union{AbstractArray, Nothing} = nothing
|
alif_vRest::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_zt0::Union{AbstractArray, Nothing} = nothing
|
alif_zt::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_zt1::Union{AbstractArray, Nothing} = nothing
|
alif_zt4d::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_refractoryCounter::Union{AbstractArray, Nothing} = nothing
|
alif_refractoryCounter::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_refractoryDuration::Union{AbstractArray, Nothing} = nothing
|
alif_refractoryDuration::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_alpha::Union{AbstractArray, Nothing} = nothing
|
alif_alpha::Union{AbstractArray, Nothing} = nothing
|
||||||
@@ -84,18 +84,18 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
alif_eta::Union{AbstractArray, Nothing} = nothing
|
alif_eta::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_gammaPd::Union{AbstractArray, Nothing} = nothing
|
alif_gammaPd::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_wRecChange::Union{AbstractArray, Nothing} = nothing
|
alif_wRecChange::Union{AbstractArray, Nothing} = nothing
|
||||||
|
alif_error::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
alif_firingCounter::Union{AbstractArray, Nothing} = nothing
|
alif_firingCounter::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
# pre-allocation array
|
# pre-allocation array
|
||||||
alif_arrayProjection3DTo4D::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d
|
alif_arrayProjection4d::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d
|
||||||
alif_recSignal::Union{AbstractArray, Nothing} = nothing
|
alif_recSignal::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_decayed_vt0::Union{AbstractArray, Nothing} = nothing
|
# alif_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
# alif_vt_diff_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_vt1_diff_vth::Union{AbstractArray, Nothing} = nothing
|
# alif_vt_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_vt1_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing
|
# alif_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing
|
# alif_phiActivation::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_phiActivation::Union{AbstractArray, Nothing} = nothing
|
|
||||||
|
|
||||||
# alif specific variables
|
# alif specific variables
|
||||||
alif_epsilonRecA::Union{AbstractArray, Nothing} = nothing
|
alif_epsilonRecA::Union{AbstractArray, Nothing} = nothing
|
||||||
@@ -106,11 +106,11 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
alif_tau_a::Union{AbstractFloat, Nothing} = nothing # τ_a, adaption time constant in millisecond
|
alif_tau_a::Union{AbstractFloat, Nothing} = nothing # τ_a, adaption time constant in millisecond
|
||||||
|
|
||||||
# alif specific pre-allocation array
|
# alif specific pre-allocation array
|
||||||
alif_phi_x_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
# alif_phi_x_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_phi_x_beta::Union{AbstractArray, Nothing} = nothing
|
# alif_phi_x_beta::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_rho_diff_phi_x_beta::Union{AbstractArray, Nothing} = nothing
|
# alif_rho_diff_phi_x_beta::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_rho_div_phi_x_beta_x_epsilonRecA::Union{AbstractArray, Nothing} = nothing
|
# alif_rho_div_phi_x_beta_x_epsilonRecA::Union{AbstractArray, Nothing} = nothing
|
||||||
alif_beta_x_a::Union{AbstractArray, Nothing} = nothing
|
# alif_beta_x_a::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# Output Neurons #
|
# Output Neurons #
|
||||||
@@ -120,12 +120,11 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
|
|
||||||
# main variables according to papers
|
# main variables according to papers
|
||||||
on_wOut::Union{AbstractArray, Nothing} = nothing # wOut is wRec, just use the name from paper
|
on_wOut::Union{AbstractArray, Nothing} = nothing # wOut is wRec, just use the name from paper
|
||||||
on_vt0::Union{AbstractArray, Nothing} = nothing
|
on_vt::Union{AbstractArray, Nothing} = nothing
|
||||||
on_vt1::Union{AbstractArray, Nothing} = nothing
|
|
||||||
on_vth::Union{AbstractArray, Nothing} = nothing
|
on_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
on_vRest::Union{AbstractArray, Nothing} = nothing
|
on_vRest::Union{AbstractArray, Nothing} = nothing
|
||||||
on_zt0::Union{AbstractArray, Nothing} = nothing
|
on_zt::Union{AbstractArray, Nothing} = nothing
|
||||||
on_zt1::Union{AbstractArray, Nothing} = nothing
|
on_zt4d::Union{AbstractArray, Nothing} = nothing
|
||||||
on_refractoryCounter::Union{AbstractArray, Nothing} = nothing
|
on_refractoryCounter::Union{AbstractArray, Nothing} = nothing
|
||||||
on_refractoryDuration::Union{AbstractArray, Nothing} = nothing
|
on_refractoryDuration::Union{AbstractArray, Nothing} = nothing
|
||||||
on_alpha::Union{AbstractArray, Nothing} = nothing
|
on_alpha::Union{AbstractArray, Nothing} = nothing
|
||||||
@@ -137,18 +136,18 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
on_eta::Union{AbstractArray, Nothing} = nothing
|
on_eta::Union{AbstractArray, Nothing} = nothing
|
||||||
on_gammaPd::Union{AbstractArray, Nothing} = nothing
|
on_gammaPd::Union{AbstractArray, Nothing} = nothing
|
||||||
on_wOutChange::Union{AbstractArray, Nothing} = nothing
|
on_wOutChange::Union{AbstractArray, Nothing} = nothing
|
||||||
|
on_error::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
on_firingCounter::Union{AbstractArray, Nothing} = nothing
|
on_firingCounter::Union{AbstractArray, Nothing} = nothing
|
||||||
|
|
||||||
# pre-allocation array
|
# pre-allocation array
|
||||||
on_arrayProjection3DTo4D::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d
|
on_arrayProjection4d::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d
|
||||||
on_recSignal::Union{AbstractArray, Nothing} = nothing
|
on_recSignal::Union{AbstractArray, Nothing} = nothing
|
||||||
on_decayed_vt0::Union{AbstractArray, Nothing} = nothing
|
# on_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
||||||
on_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing
|
# on_vt_diff_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
on_vt1_diff_vth::Union{AbstractArray, Nothing} = nothing
|
# on_vt_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
on_vt1_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing
|
# on_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing
|
||||||
on_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing
|
# on_phiActivation::Union{AbstractArray, Nothing} = nothing
|
||||||
on_phiActivation::Union{AbstractArray, Nothing} = nothing
|
|
||||||
end
|
end
|
||||||
|
|
||||||
# outer constructor
|
# outer constructor
|
||||||
@@ -169,7 +168,8 @@ function kfn_1(params::Dict; device=cpu)
|
|||||||
col += kfn.params[:computeNeuron][:alif][:numbers][2]
|
col += kfn.params[:computeNeuron][:alif][:numbers][2]
|
||||||
|
|
||||||
# activation matrix
|
# activation matrix
|
||||||
kfn.zit = zeros(row, col, batch) |> device
|
kfn.zit = zeros(row, col, batch) |> device
|
||||||
|
kfn.modelError = zeros(1) |> device
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# LIF config #
|
# LIF config #
|
||||||
@@ -190,35 +190,34 @@ function kfn_1(params::Dict; device=cpu)
|
|||||||
end
|
end
|
||||||
# project 3D w into 4D kfn.lif_wRec (row, col, n, batch)
|
# project 3D w into 4D kfn.lif_wRec (row, col, n, batch)
|
||||||
kfn.lif_wRec = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device
|
kfn.lif_wRec = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device
|
||||||
kfn.lif_zit = similar(kfn.lif_wRec) .= 0 |> device
|
kfn.lif_zit = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_vt0 = zeros(1, 1, n, batch) |> device
|
kfn.lif_vt = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_vt1 = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_vth = (similar(kfn.lif_wRec) .= 1) |> device
|
||||||
kfn.lif_vth = similar(kfn.lif_vt0) .= 1 |> device
|
kfn.lif_vRest = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_vRest = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_zt = zeros(1, 1, n, batch) |> device
|
||||||
kfn.lif_zt0 = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_zt4d = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_zt1 = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_refractoryCounter = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_refractoryCounter = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_refractoryDuration = (similar(kfn.lif_wRec) .= 3) |> device
|
||||||
kfn.lif_refractoryDuration = similar(kfn.lif_vt0) .= 3 |> device
|
|
||||||
kfn.lif_delta = 1.0
|
kfn.lif_delta = 1.0
|
||||||
kfn.lif_tau_m = 20.0
|
kfn.lif_tau_m = 20.0
|
||||||
kfn.lif_alpha = similar(kfn.lif_vt0) .= (exp(-kfn.lif_delta / kfn.lif_tau_m)) |> device
|
kfn.lif_alpha = (similar(kfn.lif_wRec) .= (exp(-kfn.lif_delta / kfn.lif_tau_m))) |> device
|
||||||
kfn.lif_phi = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_phi = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_epsilonRec = similar(kfn.lif_wRec) .= 0 |> device
|
kfn.lif_epsilonRec = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_eRec = similar(kfn.lif_wRec) .= 0 |> device
|
kfn.lif_eRec = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_eta = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_eta = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_gammaPd = similar(kfn.lif_vt0) .= 0.3 |> device
|
kfn.lif_gammaPd = (similar(kfn.lif_wRec) .= 0.3) |> device
|
||||||
kfn.lif_wRecChange = similar(kfn.lif_wRec) .= 0 |> device
|
kfn.lif_wRecChange = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
|
kfn.lif_error = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
|
|
||||||
kfn.lif_firingCounter = similar(kfn.lif_vt0) .= 0 |> device
|
kfn.lif_firingCounter = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
|
|
||||||
kfn.lif_arrayProjection3DTo4D = similar(kfn.lif_wRec) .= 1 |> device
|
kfn.lif_arrayProjection4d = (similar(kfn.lif_wRec) .= 1) |> device
|
||||||
kfn.lif_recSignal = similar(kfn.lif_wRec) .= 0 |> device
|
kfn.lif_recSignal = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_decayed_vt0 = similar(kfn.lif_vt0) .= 0 |> device
|
# kfn.lif_decayed_epsilonRec = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_decayed_epsilonRec = similar(kfn.lif_wRec) .= 0 |> device
|
# kfn.lif_vt_diff_vth = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_vt1_diff_vth = similar(kfn.lif_vt0) .= 0 |> device
|
# kfn.lif_vt_diff_vth_div_vth = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_vt1_diff_vth_div_vth = similar(kfn.lif_vt0) .= 0 |> device
|
# kfn.lif_gammaPd_div_vth = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_gammaPd_div_vth = similar(kfn.lif_vt0) .= 0 |> device
|
# kfn.lif_phiActivation = (similar(kfn.lif_wRec) .= 0) |> device
|
||||||
kfn.lif_phiActivation = similar(kfn.lif_vt0) .= 0 |> device
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# ALIF config #
|
# ALIF config #
|
||||||
@@ -237,48 +236,47 @@ function kfn_1(params::Dict; device=cpu)
|
|||||||
end
|
end
|
||||||
# project 3D w into 4D kfn.alif_wRec
|
# project 3D w into 4D kfn.alif_wRec
|
||||||
kfn.alif_wRec = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device
|
kfn.alif_wRec = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device
|
||||||
kfn.alif_zit = similar(kfn.alif_wRec) .= 0 |> device
|
kfn.alif_zit = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_vt0 = zeros(1, 1, n, batch) |> device
|
kfn.alif_vt = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_vt1 = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_vth = (similar(kfn.alif_wRec) .= 1) |> device
|
||||||
kfn.alif_vth = similar(kfn.alif_vt0) .= 1 |> device
|
kfn.alif_vRest = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_vRest = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_zt = zeros(1, 1, n, batch) |> device
|
||||||
kfn.alif_zt0 = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_zt4d = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_zt1 = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_refractoryCounter = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_refractoryCounter = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_refractoryDuration = (similar(kfn.alif_wRec) .= 3) |> device
|
||||||
kfn.alif_refractoryDuration = similar(kfn.alif_vt0) .= 3 |> device
|
|
||||||
kfn.alif_delta = 1.0
|
kfn.alif_delta = 1.0
|
||||||
kfn.alif_tau_m = 20.0
|
kfn.alif_tau_m = 20.0
|
||||||
kfn.alif_alpha = similar(kfn.alif_vt0) .= (exp(-kfn.alif_delta / kfn.alif_tau_m)) |> device
|
kfn.alif_alpha = (similar(kfn.alif_wRec) .= (exp(-kfn.alif_delta / kfn.alif_tau_m))) |> device
|
||||||
kfn.alif_phi = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_phi = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_epsilonRec = similar(kfn.alif_wRec) .= 0 |> device
|
kfn.alif_epsilonRec = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_eRec = similar(kfn.alif_wRec) .= 0 |> device
|
kfn.alif_eRec = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_eta = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_eta = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_gammaPd = similar(kfn.alif_vt0) .= 0.3 |> device
|
kfn.alif_gammaPd = (similar(kfn.alif_wRec) .= 0.3) |> device
|
||||||
kfn.alif_wRecChange = similar(kfn.alif_wRec) .= 0 |> device
|
kfn.alif_wRecChange = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
|
kfn.alif_error = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
|
|
||||||
kfn.alif_firingCounter = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_firingCounter = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
|
|
||||||
kfn.alif_arrayProjection3DTo4D = similar(kfn.alif_wRec) .= 1 |> device
|
kfn.alif_arrayProjection4d = (similar(kfn.alif_wRec) .= 1) |> device
|
||||||
kfn.alif_recSignal = similar(kfn.alif_wRec) .= 0 |> device
|
kfn.alif_recSignal = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_decayed_vt0 = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_decayed_epsilonRec = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_decayed_epsilonRec = similar(kfn.alif_wRec) .= 0 |> device
|
# kfn.alif_vt_diff_vth = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_vt1_diff_vth = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_vt_diff_vth_div_vth = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_vt1_diff_vth_div_vth = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_gammaPd_div_vth = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_gammaPd_div_vth = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_phiActivation = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_phiActivation = similar(kfn.alif_vt0) .= 0 |> device
|
|
||||||
|
|
||||||
# alif specific variables
|
# alif specific variables
|
||||||
kfn.alif_epsilonRecA = similar(kfn.alif_wRec) .= 0 |> device
|
kfn.alif_epsilonRecA = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_avth = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_avth = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_a = similar(kfn.alif_vt0) .= 0 |> device
|
kfn.alif_a = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_beta = similar(kfn.alif_vt0) .= 0.07 |> device
|
kfn.alif_beta = (similar(kfn.alif_wRec) .= 0.07) |> device
|
||||||
kfn.alif_tau_a = 100.0
|
kfn.alif_tau_a = 100.0
|
||||||
kfn.alif_rho = similar(kfn.alif_vt0) .= (exp(-kfn.alif_delta / kfn.alif_tau_a)) |> device
|
kfn.alif_rho = (similar(kfn.alif_wRec) .= (exp(-kfn.alif_delta / kfn.alif_tau_a))) |> device
|
||||||
kfn.alif_phi_x_epsilonRec = similar(kfn.alif_wRec) .= 0 |> device
|
# kfn.alif_phi_x_epsilonRec = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_phi_x_beta = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_phi_x_beta = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_rho_diff_phi_x_beta = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_rho_diff_phi_x_beta = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_rho_div_phi_x_beta_x_epsilonRecA = similar(kfn.alif_wRec) .= 0 |> device
|
# kfn.alif_rho_div_phi_x_beta_x_epsilonRecA = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
kfn.alif_beta_x_a = similar(kfn.alif_vt0) .= 0 |> device
|
# kfn.alif_beta_x_a = (similar(kfn.alif_wRec) .= 0) |> device
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# output config #
|
# output config #
|
||||||
@@ -297,43 +295,49 @@ function kfn_1(params::Dict; device=cpu)
|
|||||||
end
|
end
|
||||||
# project 3D w into 4D kfn.lif_wOut (row, col, n, batch)
|
# project 3D w into 4D kfn.lif_wOut (row, col, n, batch)
|
||||||
kfn.on_wOut = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device
|
kfn.on_wOut = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device
|
||||||
kfn.on_zit = similar(kfn.on_wOut) .= 0 |> device
|
kfn.on_zit = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_vt0 = zeros(1, 1, n, batch) |> device
|
kfn.on_vt = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_vt1 = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_vth = (similar(kfn.on_wOut) .= 1) |> device
|
||||||
kfn.on_vth = similar(kfn.on_vt0) .= 1 |> device
|
kfn.on_vRest = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_vRest = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_zt = zeros(1, 1, n, batch) |> device
|
||||||
kfn.on_zt0 = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_zt4d = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_zt1 = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_refractoryCounter = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_refractoryCounter = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_refractoryDuration = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_refractoryDuration = similar(kfn.on_vt0) .= 0 |> device
|
|
||||||
kfn.on_delta = 1.0
|
kfn.on_delta = 1.0
|
||||||
kfn.on_tau_m = 20.0
|
kfn.on_tau_m = 20.0
|
||||||
kfn.on_alpha = similar(kfn.on_vt0) .= (exp(-kfn.on_delta / kfn.on_tau_m)) |> device
|
kfn.on_alpha = (similar(kfn.on_wOut) .= (exp(-kfn.on_delta / kfn.on_tau_m))) |> device
|
||||||
kfn.on_phi = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_phi = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_epsilonRec = similar(kfn.on_wOut) .= 0 |> device
|
kfn.on_epsilonRec = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_eRec = similar(kfn.on_wOut) .= 0 |> device
|
kfn.on_eRec = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_eta = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_eta = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_gammaPd = similar(kfn.on_vt0) .= 0.3 |> device
|
kfn.on_gammaPd = (similar(kfn.on_wOut) .= 0.3) |> device
|
||||||
kfn.on_wOutChange = similar(kfn.on_wOut) .= 0 |> device
|
kfn.on_wOutChange = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
|
kfn.on_error = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
|
|
||||||
kfn.on_firingCounter = similar(kfn.on_vt0) .= 0 |> device
|
kfn.on_firingCounter = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
|
|
||||||
kfn.on_arrayProjection3DTo4D = similar(kfn.on_wOut) .= 1 |> device
|
kfn.on_arrayProjection4d = (similar(kfn.on_wOut) .= 1) |> device
|
||||||
kfn.on_recSignal = similar(kfn.on_wOut) .= 0 |> device
|
kfn.on_recSignal = (similar(kfn.on_wOut) .= 0) |> device
|
||||||
kfn.on_decayed_vt0 = similar(kfn.on_vt0) .= 0 |> device
|
|
||||||
kfn.on_decayed_epsilonRec = similar(kfn.on_wOut) .= 0 |> device
|
|
||||||
kfn.on_vt1_diff_vth = similar(kfn.on_vt0) .= 0 |> device
|
|
||||||
kfn.on_vt1_diff_vth_div_vth = similar(kfn.on_vt0) .= 0 |> device
|
kfn.outputError = zeros(n, batch) |> device
|
||||||
kfn.on_gammaPd_div_vth = similar(kfn.on_vt0) .= 0 |> device
|
|
||||||
kfn.on_phiActivation = similar(kfn.on_vt0) .= 0 |> device
|
|
||||||
|
|
||||||
|
|
||||||
|
# kfn.on_decayed_epsilonRec = (similar(kfn.on_wOut) .= 0 |> device
|
||||||
|
# kfn.on_vt_diff_vth = (similar(kfn.on_wOut) .= 0 |> device
|
||||||
|
# kfn.on_vt_diff_vth_div_vth = (similar(kfn.on_wOut) .= 0 |> device
|
||||||
|
# kfn.on_gammaPd_div_vth = (similar(kfn.on_wOut) .= 0 |> device
|
||||||
|
# kfn.on_phiActivation = (similar(kfn.on_wOut) .= 0 |> device
|
||||||
|
|
||||||
# kfn.on_zit = zeros(row, col, n, batch) |> device
|
# kfn.on_zit = zeros(row, col, n, batch) |> device
|
||||||
# kfn.on_vt0 = zeros(1, 1, n, batch) |> device
|
# kfn.on_vt = zeros(1, 1, n, batch) |> device
|
||||||
# kfn.on_vt1 = zeros(1, 1, n, batch) |> device
|
|
||||||
# kfn.on_vth = ones(1, 1, n, batch) |> device
|
# kfn.on_vth = ones(1, 1, n, batch) |> device
|
||||||
# kfn.on_vRest = zeros(1, 1, n, batch) |> device
|
# kfn.on_vRest = zeros(1, 1, n, batch) |> device
|
||||||
# # kfn.on_zt0 = zeros(1, 1, n, batch) |> device
|
# # kfn.on_zt = zeros(1, 1, n, batch) |> device
|
||||||
# kfn.on_zt1 = zeros(1, 1, n, batch) |> device
|
# kfn.on_zt4d = zeros(1, 1, n, batch) |> device
|
||||||
# kfn.on_refractoryCounter = zeros(1, 1, n, batch) |> device
|
# kfn.on_refractoryCounter = zeros(1, 1, n, batch) |> device
|
||||||
# kfn.on_refractoryDuration = ones(1, 1, n, batch) .* 0 |> device
|
# kfn.on_refractoryDuration = ones(1, 1, n, batch) .* 0 |> device
|
||||||
# kfn.on_delta = 1.0
|
# kfn.on_delta = 1.0
|
||||||
@@ -350,7 +354,7 @@ function kfn_1(params::Dict; device=cpu)
|
|||||||
|
|
||||||
# kfn.on_firingCounter = zeros(1, 1, n, batch) |> device
|
# kfn.on_firingCounter = zeros(1, 1, n, batch) |> device
|
||||||
# kfn.on_arraySize = [row, col, n, batch] |> device
|
# kfn.on_arraySize = [row, col, n, batch] |> device
|
||||||
# kfn.on_arrayProjection3DTo4D = ones(row, col, n, batch) |> device
|
# kfn.on_arrayProjection4d = ones(row, col, n, batch) |> device
|
||||||
|
|
||||||
# # subscription
|
# # subscription
|
||||||
# w = zeros(row, col, n)
|
# w = zeros(row, col, n)
|
||||||
|
|||||||
Reference in New Issue
Block a user