956 lines
41 KiB
Julia
956 lines
41 KiB
Julia
module forward
|
|
|
|
# export
|
|
|
|
using Flux, CUDA
|
|
using GeneralUtils
|
|
using ..type, ..snnUtil
|
|
|
|
#------------------------------------------------------------------------------------------------100
|
|
|
|
"""
    (kfn::kfn_1)(input::AbstractArray)

kfn forward: advance the network by one time step.

`input` is indexed as `(row, col, batch)`.

Each call

1. increments `kfn.timeStep`; when `kfn.learningStage[1] == 1` it first resets the time
   step and the learning state listed in the body, then sets the stage to `[2]`;
2. rebuilds `kfn.zit` by concatenating `input`, `kfn.lif_zt` and `kfn.alif_zt` along
   dimension 2;
3. runs `lifForward` and `alifForward` in two tasks, waits for both, and collapses their
   `zt4d` outputs into `kfn.lif_zt` / `kfn.alif_zt` with a `max` over dimensions 1 and 2;
4. rebuilds `kfn.zit` from the fresh spikes, records it in `kfn.zitCumulative` and runs
   `onForward`.

Returns `(logit, kfn.zit)`, where `logit` is `kfn.on_zt` reshaped to
`(size(input, 1), :)`. `kfn.zit` is returned without copying and is overwritten in place
by the next call.
"""
function (kfn::kfn_1)(input::AbstractArray)
    kfn.timeStep .+= 1

    # what to do at the start of learning round
    if view(kfn.learningStage, 1)[1] == 1
        kfn.timeStep .= 1

        # reset learning params
        # The history is replaced by its own zeroed first slice; the all-zero sum is what
        # marks it as "empty" further down.
        kfn.zitCumulative = (kfn.zitCumulative[:, :, 1] .= 0)

        kfn.lif_vt .= 0
        kfn.lif_wRecChange .= 0
        kfn.lif_epsilonRec .= 0
        kfn.lif_firingCounter .= 0
        kfn.lif_refractoryCounter .= 0
        kfn.lif_zt .= 0
        kfn.lif_synapticActivityCounter .= 0

        kfn.alif_vt .= 0
        kfn.alif_a .= 0
        kfn.alif_epsilonRec .= 0
        kfn.alif_epsilonRecA .= 0
        kfn.alif_wRecChange .= 0
        kfn.alif_firingCounter .= 0
        kfn.alif_refractoryCounter .= 0
        kfn.alif_zt .= 0
        kfn.alif_synapticActivityCounter .= 0

        kfn.on_vt .= 0
        kfn.on_epsilonRec .= 0
        kfn.on_wOutChange .= 0
        kfn.on_refractoryCounter .= 0
        kfn.on_synapticActivityCounter .= 0

        kfn.learningStage = [2]
    end

    rows, batch = size(input, 1), size(input, 3)

    # Rebuild the activation matrix `kfn.zit` in place by concatenating
    # (input, lif_zt, alif_zt) along dimension 2.
    function updateZit!()
        stacked = cat(reshape(input, (rows, size(input, 2), 1, batch)),
                      reshape(kfn.lif_zt, (rows, :, 1, batch)),
                      reshape(kfn.alif_zt, (rows, :, 1, batch)); dims=2)
        kfn.zit .= reshape(stacked, (rows, :, batch))
        return nothing
    end

    # activation matrix built from the spikes of the previous step
    updateZit!()

    # The two tasks use their own local names (`lifDims`, `alifDims`). Names that are also
    # assigned at function scope would be captured and shared between both closures.
    @sync begin
        @async begin
            # project 3D kfn zit into 4D lif zit
            lifDims = size(kfn.lif_zit)
            kfn.lif_zit .= reshape(kfn.zit, (lifDims[1], lifDims[2], 1, lifDims[4])) .*
                           kfn.lif_arrayProjection4d
            kfn.lif_exInType .= kfn.exInType .* kfn.lif_arrayProjection4d

            lifForward(kfn.lif_zit,
                       kfn.lif_wRec,
                       kfn.lif_vt,
                       kfn.lif_vth,
                       kfn.lif_vRest,
                       kfn.lif_zt4d,
                       kfn.lif_alpha,
                       kfn.lif_phi,
                       kfn.lif_epsilonRec,
                       kfn.lif_refractoryCounter,
                       kfn.lif_refractoryDuration,
                       kfn.lif_gammaPd,
                       kfn.lif_firingCounter,
                       kfn.lif_recSignal,
                       kfn.lif_exInType,
                       kfn.lif_wRecChange,
                       kfn.lif_neuronInactivityCounter,
                       kfn.lif_synapseReconnectDelay,
                       kfn.lif_synapticActivityCounter,
                       kfn.timeStep,
                       )
        end
        @async begin
            # project 3D kfn zit into 4D alif zit
            alifDims = size(kfn.alif_zit)
            kfn.alif_zit .= reshape(kfn.zit, (alifDims[1], alifDims[2], 1, alifDims[4])) .*
                            kfn.alif_arrayProjection4d
            kfn.alif_exInType .= kfn.exInType .* kfn.alif_arrayProjection4d

            alifForward(kfn.alif_zit,
                        kfn.alif_wRec,
                        kfn.alif_vt,
                        kfn.alif_vth,
                        kfn.alif_vRest,
                        kfn.alif_zt4d,
                        kfn.alif_alpha,
                        kfn.alif_phi,
                        kfn.alif_epsilonRec,
                        kfn.alif_refractoryCounter,
                        kfn.alif_refractoryDuration,
                        kfn.alif_gammaPd,
                        kfn.alif_firingCounter,
                        kfn.alif_recSignal,
                        kfn.alif_exInType,
                        kfn.alif_wRecChange,
                        kfn.alif_neuronInactivityCounter,
                        kfn.alif_synapseReconnectDelay,
                        kfn.alif_synapticActivityCounter,
                        kfn.timeStep,

                        kfn.alif_epsilonRecA,
                        kfn.alif_a,
                        kfn.alif_avth,
                        kfn.alif_beta,
                        kfn.alif_rho,
                        )
        end
    end

    # reduce lif_zt4d and alif_zt4d into lif_zt, alif_zt (max over dimensions 1 and 2)
    kfn.lif_zt .= reduce(max, kfn.lif_zt4d, dims=(1, 2))
    kfn.alif_zt .= reduce(max, kfn.alif_zt4d, dims=(1, 2))

    # activation matrix rebuilt with the spikes produced by this step
    updateZit!()

    # `kfn.zit` is overwritten in place on every call, so the first history entry must be a
    # copy; storing `kfn.zit` itself would make the history alias the live buffer.
    # NOTE(review): an all-zero history is treated as "empty", so leading all-zero steps
    # are not recorded — confirm this is intended.
    kfn.zitCumulative = sum(kfn.zitCumulative) == 0 ? copy(kfn.zit) :
                        cat(kfn.zitCumulative, kfn.zit, dims=3)

    # project 3D kfn zit into 4D on zit
    onDims = size(kfn.on_zit)
    kfn.on_zit .= reshape(kfn.zit, (onDims[1], onDims[2], 1, onDims[4])) .*
                  kfn.on_arrayProjection4d

    # read out
    onForward(kfn.on_zit,
              kfn.on_wOut,
              kfn.on_vt,
              kfn.on_vth,
              kfn.on_vRest,
              kfn.on_zt4d,
              kfn.on_alpha,
              kfn.on_phi,
              kfn.on_epsilonRec,
              kfn.on_refractoryCounter,
              kfn.on_refractoryDuration,
              kfn.on_gammaPd,
              kfn.on_firingCounter,
              kfn.on_recSignal,
              kfn.on_synapticActivityCounter,
              )

    # get on_zt4d to on_zt
    kfn.on_zt .= reduce(max, kfn.on_zt4d, dims=(1, 2))
    logit = reshape(kfn.on_zt, (size(input, 1), :)) # presumably (outputNeurons, batch)

    return logit, kfn.zit
end
|
|
|
|
# gpu launcher
|
|
"""
    lifForward(zit::CuArray, wRec::CuArray, ..., timeStep::CuArray)

Host-side launcher for the `lifForward` GPU kernel.

Compiles the kernel for the given arguments (plus `GeneralUtils.linear_to_cartesian`),
picks a block size from `launch_configuration`, launches enough blocks to give every
element of `wRec` one thread, and waits for the kernel to finish.
"""
function lifForward(zit::CuArray,
                    wRec::CuArray,
                    vt::CuArray,
                    vth::CuArray,
                    vRest::CuArray,
                    zt::CuArray,
                    alpha::CuArray,
                    phi::CuArray,
                    epsilonRec::CuArray,
                    refractoryCounter::CuArray,
                    refractoryDuration::CuArray,
                    gammaPd::CuArray,
                    firingCounter::CuArray,
                    recSignal::CuArray,
                    exInType::CuArray,
                    wRecChange::CuArray,
                    neuronInactivityCounter::CuArray,
                    synapseReconnectDelay::CuArray,
                    synapticActivityCounter::CuArray,
                    timeStep::CuArray,
                    )
    # Everything the kernel receives, in kernel-signature order. The index helper is
    # passed along as the last argument.
    kernelArgs = (zit, wRec, vt, vth, vRest, zt, alpha, phi, epsilonRec,
                  refractoryCounter, refractoryDuration, gammaPd, firingCounter,
                  recSignal, exInType, wRecChange, neuronInactivityCounter,
                  synapseReconnectDelay, synapticActivityCounter, timeStep,
                  GeneralUtils.linear_to_cartesian)

    # compile only; the launch happens below once the configuration is known
    kernel = @cuda launch=false lifForward(kernelArgs...)
    config = launch_configuration(kernel.fun)

    # Block size: at most 1024 threads. More threads than elements may be launched; the
    # kernel guards against the surplus indices.
    threads = min(1024, config.threads)

    # one thread per element of wRec, rounded up to whole blocks
    blocks = cld(length(wRec), threads)

    CUDA.@sync begin
        kernel(kernelArgs...; threads, blocks)
    end
end
|
|
|
|
# gpu kernel
|
|
"""
    lifForward(zit, wRec, ..., timeStep, linear_to_cartesian)

GPU kernel: one LIF update step for the element of `wRec` owned by the calling thread.

The global thread index is turned into a 4-D index with
`linear_to_cartesian(i, size(wRec))`; threads whose index exceeds `length(wRec)` do
nothing. Every array argument is read and written at that same 4-D index, except for the
per-slice input sum and `sum(timeStep)`. All results are written in place; the kernel
returns `nothing`.

NOTE(review): assumes all array arguments share the shape of `wRec` — confirm.
"""
function lifForward(zit,
                    wRec,
                    vt,
                    vth,
                    vRest,
                    zt,
                    alpha,
                    phi,
                    epsilonRec,
                    refractoryCounter,
                    refractoryDuration,
                    gammaPd,
                    firingCounter,
                    recSignal,
                    exInType,
                    wRecChange,
                    neuronInactivityCounter,
                    synapseReconnectDelay,
                    synapticActivityCounter,
                    timeStep,
                    linear_to_cartesian,
                    )
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x # gpu threads index

    if i <= length(wRec)
        # cartesian index
        i1, i2, i3, i4 = linear_to_cartesian(i, size(wRec))

        if refractoryCounter[i1, i2, i3, i4] > 0 # refractory period is active
            refractoryCounter[i1, i2, i3, i4] -= 1
            recSignal[i1, i2, i3, i4] = 0
            zt[i1, i2, i3, i4] = 0
            vt[i1, i2, i3, i4] = alpha[i1, i2, i3, i4] * vt[i1, i2, i3, i4]
            phi[i1, i2, i3, i4] = 0

            # compute epsilonRec (decay only)
            epsilonRec[i1, i2, i3, i4] = alpha[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4]

        else # refractory period is inactive
            recSignal[i1, i2, i3, i4] = wRec[i1, i2, i3, i4] * zit[i1, i2, i3, i4] *
                                        exInType[i1, i2, i3, i4]

            # Total input of slice (:, :, i3, i4). It is rebuilt from wRec/zit/exInType,
            # which this kernel never writes. Summing recSignal[:, :, i3, i4] instead
            # would read entries that other threads of the same launch are writing at the
            # same time, with no barrier in between.
            # NOTE(review): this counts every synapse of the slice as non-refractory,
            # i.e. assumes the refractory state is identical across the slice — confirm.
            totalSignal = zero(eltype(recSignal))
            for c in axes(recSignal, 2), r in axes(recSignal, 1)
                totalSignal += convert(eltype(recSignal),
                                       wRec[r, c, i3, i4] * zit[r, c, i3, i4] *
                                       exInType[r, c, i3, i4])
            end
            vt[i1, i2, i3, i4] = (alpha[i1, i2, i3, i4] * vt[i1, i2, i3, i4]) + totalSignal

            # fires if membrane potential exceed threshold
            if vt[i1, i2, i3, i4] > vth[i1, i2, i3, i4]
                zt[i1, i2, i3, i4] = 1
                refractoryCounter[i1, i2, i3, i4] = refractoryDuration[i1, i2, i3, i4]
                firingCounter[i1, i2, i3, i4] += 1
                vt[i1, i2, i3, i4] = vRest[i1, i2, i3, i4]

                # reset counter if neuron fires
                neuronInactivityCounter[i1, i2, i3, i4] = 0
            else
                zt[i1, i2, i3, i4] = 0
                neuronInactivityCounter[i1, i2, i3, i4] -= 1
            end

            # compute phi, there is a difference from lif formula
            # (vt has already been reset to vRest here if the neuron fired)
            phi[i1, i2, i3, i4] = (gammaPd[i1, i2, i3, i4] / vth[i1, i2, i3, i4]) *
                max(0, 1 - ((vt[i1, i2, i3, i4] - vth[i1, i2, i3, i4]) / vth[i1, i2, i3, i4]))

            # compute epsilonRec
            # !iszero indicates synaptic subscription
            epsilonRec[i1, i2, i3, i4] = (alpha[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4]) +
                (zit[i1, i2, i3, i4] * !iszero(wRec[i1, i2, i3, i4]))

            synapticActivityCounter[i1, i2, i3, i4] +=
                zit[i1, i2, i3, i4] * !iszero(wRec[i1, i2, i3, i4])

            # voltage regulator
            wRecChange[i1, i2, i3, i4] = -0.01*0.0001 *
                (vt[i1, i2, i3, i4] - vth[i1, i2, i3, i4]) * zit[i1, i2, i3, i4]

            # negative value is counting mode, -0.1 < -0.1 won't work on GPU
            if synapseReconnectDelay[i1, i2, i3, i4] < -0.2
                synapseReconnectDelay[i1, i2, i3, i4] += 1
                if synapseReconnectDelay[i1, i2, i3, i4] == 0
                    # mark timestep
                    synapseReconnectDelay[i1, i2, i3, i4] = sum(timeStep)
                end
            end
        end
    end
    return nothing
end
|
|
|
|
# gpu launcher
|
|
"""
    alifForward(zit::CuArray, wRec::CuArray, ..., beta::CuArray, rho::CuArray)

Host-side launcher for the `alifForward` GPU kernel.

Compiles the kernel for the given arguments (plus `GeneralUtils.linear_to_cartesian`),
picks a block size from `launch_configuration`, launches enough blocks to give every
element of `wRec` one thread, and waits for the kernel to finish.
"""
function alifForward(zit::CuArray,
                     wRec::CuArray,
                     vt::CuArray,
                     vth::CuArray,
                     vRest::CuArray,
                     zt::CuArray,
                     alpha::CuArray,
                     phi::CuArray,
                     epsilonRec::CuArray,
                     refractoryCounter::CuArray,
                     refractoryDuration::CuArray,
                     gammaPd::CuArray,
                     firingCounter::CuArray,
                     recSignal::CuArray,
                     exInType::CuArray,
                     wRecChange::CuArray,
                     neuronInactivityCounter::CuArray,
                     synapseReconnectDelay::CuArray,
                     synapticActivityCounter::CuArray,
                     timeStep::CuArray,

                     epsilonRecA::CuArray,
                     a::CuArray,
                     avth::CuArray,
                     beta::CuArray,
                     rho::CuArray,
                     )
    # Everything the kernel receives, in kernel-signature order. The index helper is
    # passed along as the last argument.
    kernelArgs = (zit, wRec, vt, vth, vRest, zt, alpha, phi, epsilonRec,
                  refractoryCounter, refractoryDuration, gammaPd, firingCounter,
                  recSignal, exInType, wRecChange, neuronInactivityCounter,
                  synapseReconnectDelay, synapticActivityCounter, timeStep,
                  epsilonRecA, a, avth, beta, rho,
                  GeneralUtils.linear_to_cartesian)

    # compile only; the launch happens below once the configuration is known
    kernel = @cuda launch=false alifForward(kernelArgs...)
    config = launch_configuration(kernel.fun)

    # Block size: at most 1024 threads. More threads than elements may be launched; the
    # kernel guards against the surplus indices.
    threads = min(1024, config.threads)

    # one thread per element of wRec, rounded up to whole blocks
    blocks = cld(length(wRec), threads)

    CUDA.@sync begin
        kernel(kernelArgs...; threads, blocks)
    end
end
|
|
|
|
# gpu kernel
|
|
"""
    alifForward(zit, wRec, ..., timeStep, epsilonRecA, a, avth, beta, rho,
                linear_to_cartesian)

GPU kernel: one ALIF update step for the element of `wRec` owned by the calling thread.

The global thread index is turned into a 4-D index with
`linear_to_cartesian(i, size(wRec))`; threads whose index exceeds `length(wRec)` do
nothing. Every array argument is read and written at that same 4-D index, except for the
per-slice input sum and `sum(timeStep)`. The firing threshold is
`avth = vth + beta * a`, where `a` decays by `rho` each step and grows by 1 on a spike.
All results are written in place; the kernel returns `nothing`.

NOTE(review): assumes all array arguments share the shape of `wRec` — confirm.
"""
function alifForward(zit,
                     wRec,
                     vt,
                     vth,
                     vRest,
                     zt,
                     alpha,
                     phi,
                     epsilonRec,
                     refractoryCounter,
                     refractoryDuration,
                     gammaPd,
                     firingCounter,
                     recSignal,
                     exInType,
                     wRecChange,
                     neuronInactivityCounter,
                     synapseReconnectDelay,
                     synapticActivityCounter,
                     timeStep,

                     epsilonRecA,
                     a,
                     avth,
                     beta,
                     rho,
                     linear_to_cartesian,
                     )
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x # gpu threads index

    if i <= length(wRec)
        # cartesian index
        i1, i2, i3, i4 = linear_to_cartesian(i, size(wRec))

        if refractoryCounter[i1, i2, i3, i4] > 0 # refractory period is active
            refractoryCounter[i1, i2, i3, i4] -= 1
            recSignal[i1, i2, i3, i4] = 0
            zt[i1, i2, i3, i4] = 0
            vt[i1, i2, i3, i4] = alpha[i1, i2, i3, i4] * vt[i1, i2, i3, i4]
            phi[i1, i2, i3, i4] = 0
            a[i1, i2, i3, i4] = rho[i1, i2, i3, i4] * a[i1, i2, i3, i4]

            # compute epsilonRec (decay only)
            epsilonRec[i1, i2, i3, i4] = alpha[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4]

            # compute epsilonRecA use eq.26 (phi was just set to 0 above)
            epsilonRecA[i1, i2, i3, i4] = rho[i1, i2, i3, i4] *
                (phi[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4])

            # compute avth
            avth[i1, i2, i3, i4] = vth[i1, i2, i3, i4] +
                (beta[i1, i2, i3, i4] * a[i1, i2, i3, i4])

        else # refractory period is inactive
            recSignal[i1, i2, i3, i4] = wRec[i1, i2, i3, i4] * zit[i1, i2, i3, i4] *
                                        exInType[i1, i2, i3, i4]

            # Total input of slice (:, :, i3, i4). It is rebuilt from wRec/zit/exInType,
            # which this kernel never writes. Summing recSignal[:, :, i3, i4] instead
            # would read entries that other threads of the same launch are writing at the
            # same time, with no barrier in between.
            # NOTE(review): this counts every synapse of the slice as non-refractory,
            # i.e. assumes the refractory state is identical across the slice — confirm.
            totalSignal = zero(eltype(recSignal))
            for c in axes(recSignal, 2), r in axes(recSignal, 1)
                totalSignal += convert(eltype(recSignal),
                                       wRec[r, c, i3, i4] * zit[r, c, i3, i4] *
                                       exInType[r, c, i3, i4])
            end
            vt[i1, i2, i3, i4] = (alpha[i1, i2, i3, i4] * vt[i1, i2, i3, i4]) + totalSignal

            # compute avth (uses `a` from before this step's update)
            avth[i1, i2, i3, i4] = vth[i1, i2, i3, i4] +
                (beta[i1, i2, i3, i4] * a[i1, i2, i3, i4])

            # fires if membrane potential exceed threshold
            if vt[i1, i2, i3, i4] > avth[i1, i2, i3, i4]
                zt[i1, i2, i3, i4] = 1
                refractoryCounter[i1, i2, i3, i4] = refractoryDuration[i1, i2, i3, i4]
                firingCounter[i1, i2, i3, i4] += 1
                vt[i1, i2, i3, i4] = vRest[i1, i2, i3, i4]
                a[i1, i2, i3, i4] = (rho[i1, i2, i3, i4] * a[i1, i2, i3, i4]) + 1
                neuronInactivityCounter[i1, i2, i3, i4] = 0
            else
                zt[i1, i2, i3, i4] = 0
                a[i1, i2, i3, i4] = rho[i1, i2, i3, i4] * a[i1, i2, i3, i4]
                neuronInactivityCounter[i1, i2, i3, i4] -= 1
            end

            # compute phi, there is a difference from alif formula
            # (uses vth, not avth; vt has already been reset to vRest if the neuron fired)
            phi[i1, i2, i3, i4] = (gammaPd[i1, i2, i3, i4] / vth[i1, i2, i3, i4]) *
                max(0, 1 - ((vt[i1, i2, i3, i4] - vth[i1, i2, i3, i4]) / vth[i1, i2, i3, i4]))

            # compute epsilonRec
            # !iszero indicates synaptic subscription
            epsilonRec[i1, i2, i3, i4] = (alpha[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4]) +
                (zit[i1, i2, i3, i4] * !iszero(wRec[i1, i2, i3, i4]))

            # compute epsilonRecA use eq.26
            epsilonRecA[i1, i2, i3, i4] = (rho[i1, i2, i3, i4] *
                (phi[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4])) +
                (zit[i1, i2, i3, i4] * !iszero(wRec[i1, i2, i3, i4]))

            synapticActivityCounter[i1, i2, i3, i4] +=
                zit[i1, i2, i3, i4] * !iszero(wRec[i1, i2, i3, i4])

            # voltage regulator
            wRecChange[i1, i2, i3, i4] = -0.01*0.0001 *
                (vt[i1, i2, i3, i4] - avth[i1, i2, i3, i4]) * zit[i1, i2, i3, i4]

            # negative value is counting mode, -0.1 < -0.1 won't work on GPU
            if synapseReconnectDelay[i1, i2, i3, i4] < -0.2
                synapseReconnectDelay[i1, i2, i3, i4] += 1
                if synapseReconnectDelay[i1, i2, i3, i4] == 0
                    # mark timestep
                    synapseReconnectDelay[i1, i2, i3, i4] = sum(timeStep)
                end
            end
        end
    end
    return nothing
end
|
|
|
|
# gpu launcher
|
|
"""
    onForward(zit::CuArray, wOut::CuArray, ..., synapticActivityCounter::CuArray)

Host-side launcher for the `onForward` GPU kernel.

Compiles the kernel for the given arguments (plus `GeneralUtils.linear_to_cartesian`),
picks a block size from `launch_configuration`, launches enough blocks to give every
element of `wOut` one thread, and waits for the kernel to finish.
"""
function onForward(zit::CuArray,
                   wOut::CuArray,
                   vt::CuArray,
                   vth::CuArray,
                   vRest::CuArray,
                   zt::CuArray,
                   alpha::CuArray,
                   phi::CuArray,
                   epsilonRec::CuArray,
                   refractoryCounter::CuArray,
                   refractoryDuration::CuArray,
                   gammaPd::CuArray,
                   firingCounter::CuArray,
                   recSignal::CuArray,
                   synapticActivityCounter::CuArray,
                   )
    # Everything the kernel receives, in kernel-signature order. The index helper is
    # passed along as the last argument.
    kernelArgs = (zit, wOut, vt, vth, vRest, zt, alpha, phi, epsilonRec,
                  refractoryCounter, refractoryDuration, gammaPd, firingCounter,
                  recSignal, synapticActivityCounter,
                  GeneralUtils.linear_to_cartesian)

    # compile only; the launch happens below once the configuration is known
    kernel = @cuda launch=false onForward(kernelArgs...)
    config = launch_configuration(kernel.fun)

    # Block size: at most 1024 threads. More threads than elements may be launched; the
    # kernel guards against the surplus indices.
    threads = min(1024, config.threads)

    # one thread per element of wOut, rounded up to whole blocks
    blocks = cld(length(wOut), threads)

    CUDA.@sync begin
        kernel(kernelArgs...; threads, blocks)
    end
end
|
|
|
|
# gpu kernel
|
|
"""
    onForward(zit, wOut, ..., synapticActivityCounter, linear_to_cartesian)

GPU kernel: one output-neuron update step for the element of `wOut` owned by the calling
thread.

The global thread index is turned into a 4-D index with
`linear_to_cartesian(i, size(wOut))`; threads whose index exceeds `length(wOut)` do
nothing. Every array argument is read and written at that same 4-D index, except for the
per-slice input sum. All results are written in place; the kernel returns `nothing`.

NOTE(review): assumes all array arguments share the shape of `wOut` — confirm.
"""
function onForward(zit,
                   wOut,
                   vt,
                   vth,
                   vRest,
                   zt,
                   alpha,
                   phi,
                   epsilonRec,
                   refractoryCounter,
                   refractoryDuration,
                   gammaPd,
                   firingCounter,
                   recSignal,
                   synapticActivityCounter,
                   linear_to_cartesian,
                   )
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x # gpu threads index

    if i <= length(wOut)
        # cartesian index
        i1, i2, i3, i4 = linear_to_cartesian(i, size(wOut))

        if refractoryCounter[i1, i2, i3, i4] > 0 # refractory period is active
            refractoryCounter[i1, i2, i3, i4] -= 1
            recSignal[i1, i2, i3, i4] = 0
            zt[i1, i2, i3, i4] = 0
            vt[i1, i2, i3, i4] = alpha[i1, i2, i3, i4] * vt[i1, i2, i3, i4]
            phi[i1, i2, i3, i4] = 0

            # compute epsilonRec (decay only)
            epsilonRec[i1, i2, i3, i4] = alpha[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4]

        else # refractory period is inactive
            recSignal[i1, i2, i3, i4] = zit[i1, i2, i3, i4] * wOut[i1, i2, i3, i4]

            # Total input of slice (:, :, i3, i4). It is rebuilt from zit/wOut, which
            # this kernel never writes. Summing recSignal[:, :, i3, i4] instead would
            # read entries that other threads of the same launch are writing at the same
            # time, with no barrier in between.
            # NOTE(review): this counts every synapse of the slice as non-refractory,
            # i.e. assumes the refractory state is identical across the slice — confirm.
            totalSignal = zero(eltype(recSignal))
            for c in axes(recSignal, 2), r in axes(recSignal, 1)
                totalSignal += convert(eltype(recSignal),
                                       zit[r, c, i3, i4] * wOut[r, c, i3, i4])
            end
            vt[i1, i2, i3, i4] = (alpha[i1, i2, i3, i4] * vt[i1, i2, i3, i4]) + totalSignal

            # fires if membrane potential exceed threshold
            if vt[i1, i2, i3, i4] > vth[i1, i2, i3, i4]
                zt[i1, i2, i3, i4] = 1
                refractoryCounter[i1, i2, i3, i4] = refractoryDuration[i1, i2, i3, i4]
                firingCounter[i1, i2, i3, i4] += 1
                vt[i1, i2, i3, i4] = vRest[i1, i2, i3, i4]
            else
                zt[i1, i2, i3, i4] = 0
            end

            # compute phi, there is a difference from on formula
            # (vt has already been reset to vRest here if the neuron fired)
            phi[i1, i2, i3, i4] = (gammaPd[i1, i2, i3, i4] / vth[i1, i2, i3, i4]) *
                max(0, 1 - ((vt[i1, i2, i3, i4] - vth[i1, i2, i3, i4]) / vth[i1, i2, i3, i4]))

            # compute epsilonRec
            # !iszero indicates synaptic subscription
            epsilonRec[i1, i2, i3, i4] = (alpha[i1, i2, i3, i4] * epsilonRec[i1, i2, i3, i4]) +
                (zit[i1, i2, i3, i4] * !iszero(wOut[i1, i2, i3, i4]))

            synapticActivityCounter[i1, i2, i3, i4] +=
                zit[i1, i2, i3, i4] * !iszero(wOut[i1, i2, i3, i4])
        end
    end
    return nothing
end
|
|
|
|
# function lifForward(kfn_zit::Array{T},
|
|
# zit::Array{T},
|
|
# wRec::Array{T},
|
|
# vt0::Array{T},
|
|
# vt1::Array{T},
|
|
# vth::Array{T},
|
|
# vRest::Array{T},
|
|
# zt1::Array{T},
|
|
# alpha::Array{T},
|
|
# phi::Array{T},
|
|
# epsilonRec::Array{T},
|
|
# refractoryCounter::Array{T},
|
|
# refractoryDuration::Array{T},
|
|
# gammaPd::Array{T},
|
|
# firingCounter::Array{T},
|
|
# arrayProjection4d::Array{T},
|
|
# recSignal::Array{T},
|
|
# decayed_vt0::Array{T},
|
|
# decayed_epsilonRec::Array{T},
|
|
# vt1_diff_vth::Array{T},
|
|
# vt1_diff_vth_div_vth::Array{T},
|
|
# gammaPd_div_vth::Array{T},
|
|
# phiActivation::Array{T},
|
|
# ) where T<:Number
|
|
|
|
# # project 3D kfn zit into 4D lif zit
|
|
# i1, i2, i3, i4 = size(alif_wRec)
|
|
# lif_zit .= reshape(kfn_zit, (i1, i2, 1, i4)) .* lif_arrayProjection4d
|
|
|
|
# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
|
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
|
# @. @views refractoryCounter[:,:,i,j] -= 1
|
|
# @. @views zt1[:,:,i,j] = 0
|
|
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
|
# @. @views phi[:,:,i,j] = 0
|
|
|
|
# # compute epsilonRec
|
|
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
|
|
# else # refractory period is inactive
|
|
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j]
|
|
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
|
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
|
|
|
|
# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j]))
|
|
# @. @views zt1[:,:,i,j] = 1
|
|
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
|
|
# @. @views firingCounter[:,:,i,j] += 1
|
|
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
|
|
# else
|
|
# @. @views zt1[:,:,i,j] = 0
|
|
# end
|
|
|
|
# # compute phi, there is a difference from alif formula
|
|
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
|
|
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
|
|
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
|
|
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
|
|
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
|
|
|
|
# # compute epsilonRec
|
|
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
|
|
# end
|
|
# end
|
|
# end
|
|
|
|
# function alifForward(zit::Array{T},
|
|
# wRec::Array{T},
|
|
# vt0::Array{T},
|
|
# vt1::Array{T},
|
|
# vth::Array{T},
|
|
# vRest::Array{T},
|
|
# zt1::Array{T},
|
|
# alpha::Array{T},
|
|
# phi::Array{T},
|
|
# epsilonRec::Array{T},
|
|
# refractoryCounter::Array{T},
|
|
# refractoryDuration::Array{T},
|
|
# gammaPd::Array{T},
|
|
# firingCounter::Array{T},
|
|
# recSignal::Array{T},
|
|
# decayed_vt0::Array{T},
|
|
# decayed_epsilonRec::Array{T},
|
|
# vt1_diff_vth::Array{T},
|
|
# vt1_diff_vth_div_vth::Array{T},
|
|
# gammaPd_div_vth::Array{T},
|
|
# phiActivation::Array{T},
|
|
|
|
# epsilonRecA::Array{T},
|
|
# avth::Array{T},
|
|
# a::Array{T},
|
|
# beta::Array{T},
|
|
# rho::Array{T},
|
|
# phi_x_epsilonRec::Array{T},
|
|
# phi_x_beta::Array{T},
|
|
# rho_diff_phi_x_beta::Array{T},
|
|
# rho_div_phi_x_beta_x_epsilonRecA::Array{T},
|
|
# beta_x_a::Array{T},
|
|
# ) where T<:Number
|
|
|
|
# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch
|
|
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
|
# @. @views refractoryCounter[:,:,i,j] -= 1
|
|
# @. @views zt1[:,:,i,j] = 0
|
|
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
|
# @. @views phi[:,:,i,j] = 0
|
|
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
|
|
|
|
# # compute epsilonRec
|
|
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
|
|
|
|
# # compute epsilonRecA
|
|
# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j]
|
|
# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j]
|
|
# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j]
|
|
# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j]
|
|
|
|
# # compute avth
|
|
# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j]
|
|
# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j]
|
|
|
|
# else # refractory period is inactive
|
|
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j]
|
|
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
|
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
|
|
|
|
# # compute avth
|
|
# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j]
|
|
# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j]
|
|
|
|
# if sum(@view(vt1[:,:,i,j])) > sum(@view(avth[:,:,i,j]))
|
|
# @. @views zt1[:,:,i,j] = 1
|
|
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
|
|
# @. @views firingCounter[:,:,i,j] += 1
|
|
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
|
|
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
|
|
# @. @views a[:,:,i,j] = a[:,:,i,j] += 1
|
|
# else
|
|
# @. @views zt1[:,:,i,j] = 0
|
|
# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j]
|
|
# end
|
|
|
|
# # compute phi, there is a difference from alif formula
|
|
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
|
|
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
|
|
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
|
|
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
|
|
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
|
|
|
|
# # compute epsilonRec
|
|
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
|
|
|
|
# # compute epsilonRecA
|
|
# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j]
|
|
# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j]
|
|
# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j]
|
|
# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j]
|
|
# end
|
|
# end
|
|
# end
|
|
|
|
# function onForward(kfn_zit::Array{T},
|
|
# zit::Array{T},
|
|
# wOut::Array{T},
|
|
# vt0::Array{T},
|
|
# vt1::Array{T},
|
|
# vth::Array{T},
|
|
# vRest::Array{T},
|
|
# zt1::Array{T},
|
|
# alpha::Array{T},
|
|
# phi::Array{T},
|
|
# epsilonRec::Array{T},
|
|
# refractoryCounter::Array{T},
|
|
# refractoryDuration::Array{T},
|
|
# gammaPd::Array{T},
|
|
# firingCounter::Array{T},
|
|
# arrayProjection4d::Array{T},
|
|
# recSignal::Array{T},
|
|
# decayed_vt0::Array{T},
|
|
# decayed_epsilonRec::Array{T},
|
|
# vt1_diff_vth::Array{T},
|
|
# vt1_diff_vth_div_vth::Array{T},
|
|
# gammaPd_div_vth::Array{T},
|
|
# phiActivation::Array{T},
|
|
# ) where T<:Number
|
|
|
|
# # project 3D kfn zit into 4D lif zit
|
|
# zit .= reshape(kfn_zit,
|
|
# (size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection4d
|
|
|
|
# for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch
|
|
# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active
|
|
# @. @views refractoryCounter[:,:,i,j] -= 1
|
|
# @. @views zt1[:,:,i,j] = 0
|
|
# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
|
# @. @views phi[:,:,i,j] = 0
|
|
|
|
# # compute epsilonRec
|
|
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j]
|
|
# else # refractory period is inactive
|
|
# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wOut[:,:,i,j]
|
|
# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j]
|
|
# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j]))
|
|
|
|
# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j]))
|
|
# @. @views zt1[:,:,i,j] = 1
|
|
# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j]
|
|
# @. @views firingCounter[:,:,i,j] += 1
|
|
# @. @views vt1[:,:,i,j] = vRest[:,:,i,j]
|
|
# else
|
|
# @. @views zt1[:,:,i,j] = 0
|
|
# end
|
|
|
|
# # compute phi, there is a difference from alif formula
|
|
# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j]
|
|
# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j]
|
|
# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j]
|
|
# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j])))
|
|
# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j]
|
|
|
|
# # compute epsilonRec
|
|
# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j]
|
|
# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j]
|
|
# end
|
|
# end
|
|
# end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module |