From 65bb97baf3529c85a6700a38aa26d9113365c617 Mon Sep 17 00:00:00 2001 From: ton Date: Mon, 7 Aug 2023 17:06:58 +0700 Subject: [PATCH] bug fix replace [i] with [i1,i2,i3,i4] --- src/forward.jl | 124 ++++++++++++++++++++++++------------------------- 1 file changed, 62 insertions(+), 62 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index e11e36c..20fb70b 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -118,14 +118,14 @@ function (kfn::kfn_1)(input::AbstractArray) kfn.on_gammaPd, kfn.on_firingCounter, kfn.on_recSignal,) - # error("DEBUG -> kfn forward") + logit = reshape(kfn.on_zt, (size(input, 1), :)) +# error("DEBUG -> kfn forward") return logit, kfn.zit end - function lifForward(kfn_zit::Array{T}, zit::Array{T}, wRec::Array{T}, @@ -279,36 +279,36 @@ function lifForward( zit, i1, i2, i3, i4 = linear_to_cartesian(i, size(wRec)) # @cuprintln("gpu thread $i $i1 $i2 $i3 $i4") - refractoryCounter[i] -= 1 + refractoryCounter[i1,i2,i3,i4] -= 1 - if refractoryCounter[i] > 0 # refractory period is active - refractoryCounter[i] -= 1 - zt[i] = 0 - vt[i] = alpha[i] * vt[i] - phi[i] = 0 + if refractoryCounter[i1,i2,i3,i4] > 0 # refractory period is active + refractoryCounter[i1,i2,i3,i4] -= 1 + zt[i1,i2,i3,i4] = 0 + vt[i1,i2,i3,i4] = alpha[i1,i2,i3,i4] * vt[i1,i2,i3,i4] + phi[i1,i2,i3,i4] = 0 # compute epsilonRec - epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i] + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + zit[i1,i2,i3,i4] else # refractory period is inactive - recSignal[i] = zit[i] * wRec[i] - vt[i] = (alpha[i] * vt[i]) + sum(@view(recSignal[:,:,i3,i4])) + recSignal[i1,i2,i3,i4] = zit[i1,i2,i3,i4] * wRec[i1,i2,i3,i4] + vt[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * vt[i1,i2,i3,i4]) + sum(@view(recSignal[:,:,i3,i4])) # fires if membrane potential exceed threshold - if vt[i] > vth[i] - zt[i] = 1 - refractoryCounter[i] = refractoryDuration[i] - firingCounter[i] += 1 - vt[i] = vRest[i] + if vt[i1,i2,i3,i4] > vth[i1,i2,i3,i4] + zt[i1,i2,i3,i4] = 1 + refractoryCounter[i1,i2,i3,i4] = refractoryDuration[i1,i2,i3,i4] + firingCounter[i1,i2,i3,i4] += 1 + vt[i1,i2,i3,i4] = vRest[i1,i2,i3,i4] else - zt[i] = 0 + zt[i1,i2,i3,i4] = 0 end # compute phi, there is a difference from lif formula - phi[i] = (gammaPd[i] / vth[i]) * max(0, 1 - ((vt[i] - vth[i]) / vth[i])) + phi[i1,i2,i3,i4] = (gammaPd[i1,i2,i3,i4] / vth[i1,i2,i3,i4]) * max(0, 1 - ((vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) / vth[i1,i2,i3,i4])) # compute epsilonRec - epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i] + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + zit[i1,i2,i3,i4] end end return nothing @@ -519,53 +519,53 @@ function alifForward( zit, i1, i2, i3, i4 = linear_to_cartesian(i, size(wRec)) # @cuprintln("gpu thread $i $i1 $i2 $i3 $i4") - refractoryCounter[i] -= 1 + refractoryCounter[i1,i2,i3,i4] -= 1 - if refractoryCounter[i] > 0 # refractory period is active - refractoryCounter[i] -= 1 - zt[i] = 0 - vt[i] = alpha[i] * vt[i] - phi[i] = 0 - a[i] = rho[i] * a[i] + if refractoryCounter[i1,i2,i3,i4] > 0 # refractory period is active + refractoryCounter[i1,i2,i3,i4] -= 1 + zt[i1,i2,i3,i4] = 0 + vt[i1,i2,i3,i4] = alpha[i1,i2,i3,i4] * vt[i1,i2,i3,i4] + phi[i1,i2,i3,i4] = 0 + a[i1,i2,i3,i4] = rho[i1,i2,i3,i4] * a[i1,i2,i3,i4] # compute epsilonRec - epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i] + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + zit[i1,i2,i3,i4] # compute epsilonRecA - epsilonRecA[i] = (phi[i] * epsilonRec[i]) + - ((rho[i] - (phi[i] * beta[i])) * epsilonRecA[i]) + epsilonRecA[i1,i2,i3,i4] = (phi[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + + ((rho[i1,i2,i3,i4] - (phi[i1,i2,i3,i4] * beta[i1,i2,i3,i4])) * epsilonRecA[i1,i2,i3,i4]) # compute avth - avth[i] = vth[i] + (beta[i] * a[i]) + avth[i1,i2,i3,i4] = vth[i1,i2,i3,i4] + (beta[i1,i2,i3,i4] * a[i1,i2,i3,i4]) else # refractory period is inactive - recSignal[i] = zit[i] * wRec[i] - vt[i] = (alpha[i] * vt[i]) + sum(@view(recSignal[:,:,i3,i4])) + recSignal[i1,i2,i3,i4] = zit[i1,i2,i3,i4] * wRec[i1,i2,i3,i4] + vt[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * vt[i1,i2,i3,i4]) + sum(@view(recSignal[:,:,i3,i4])) # compute avth - avth[i] = vth[i] + (beta[i] * a[i]) + avth[i1,i2,i3,i4] = vth[i1,i2,i3,i4] + (beta[i1,i2,i3,i4] * a[i1,i2,i3,i4]) # fires if membrane potential exceed threshold - if vt[i] > avth[i] - zt[i] = 1 - refractoryCounter[i] = refractoryDuration[i] - firingCounter[i] += 1 - vt[i] = vRest[i] - a[i] = (rho[i] * a[i]) + 1 + if vt[i1,i2,i3,i4] > avth[i1,i2,i3,i4] + zt[i1,i2,i3,i4] = 1 + refractoryCounter[i1,i2,i3,i4] = refractoryDuration[i1,i2,i3,i4] + firingCounter[i1,i2,i3,i4] += 1 + vt[i1,i2,i3,i4] = vRest[i1,i2,i3,i4] + a[i1,i2,i3,i4] = (rho[i1,i2,i3,i4] * a[i1,i2,i3,i4]) + 1 else - zt[i] = 0 - a[i] = (rho[i] * a[i]) + zt[i1,i2,i3,i4] = 0 + a[i1,i2,i3,i4] = (rho[i1,i2,i3,i4] * a[i1,i2,i3,i4]) end # compute phi, there is a difference from alif formula - phi[i] = (gammaPd[i] / vth[i]) * max(0, 1 - ((vt[i] - vth[i]) / vth[i])) + phi[i1,i2,i3,i4] = (gammaPd[i1,i2,i3,i4] / vth[i1,i2,i3,i4]) * max(0, 1 - ((vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) / vth[i1,i2,i3,i4])) # compute epsilonRec - epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i] + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + zit[i1,i2,i3,i4] # compute epsilonRecA - epsilonRecA[i] = (phi[i] * epsilonRec[i]) + - ((rho[i] - (phi[i] * beta[i])) * epsilonRecA[i]) + epsilonRecA[i1,i2,i3,i4] = (phi[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + + ((rho[i1,i2,i3,i4] - (phi[i1,i2,i3,i4] * beta[i1,i2,i3,i4])) * epsilonRecA[i1,i2,i3,i4]) end end return nothing @@ -723,36 +723,36 @@ function onForward( zit, i1, i2, i3, i4 = linear_to_cartesian(i, size(wOut)) # @cuprintln("gpu thread $i $i1 $i2 $i3 $i4") - refractoryCounter[i] -= 1 + refractoryCounter[i1,i2,i3,i4] -= 1 - if refractoryCounter[i] > 0 # refractory period is active - refractoryCounter[i] -= 1 - zt[i] = 0 - vt[i] = alpha[i] * vt[i] - phi[i] = 0 + if refractoryCounter[i1,i2,i3,i4] > 0 # refractory period is active + refractoryCounter[i1,i2,i3,i4] -= 1 + zt[i1,i2,i3,i4] = 0 + vt[i1,i2,i3,i4] = alpha[i1,i2,i3,i4] * vt[i1,i2,i3,i4] + phi[i1,i2,i3,i4] = 0 # compute epsilonRec - epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i] + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + zit[i1,i2,i3,i4] else # refractory period is inactive - recSignal[i] = zit[i] * wOut[i] - vt[i] = (alpha[i] * vt[i]) + sum(@view(recSignal[:,:,i3,i4])) + recSignal[i1,i2,i3,i4] = zit[i1,i2,i3,i4] * wOut[i1,i2,i3,i4] + vt[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * vt[i1,i2,i3,i4]) + sum(@view(recSignal[:,:,i3,i4])) # fires if membrane potential exceed threshold - if vt[i] > vth[i] - zt[i] = 1 - refractoryCounter[i] = refractoryDuration[i] - firingCounter[i] += 1 - vt[i] = vRest[i] + if vt[i1,i2,i3,i4] > vth[i1,i2,i3,i4] + zt[i1,i2,i3,i4] = 1 + refractoryCounter[i1,i2,i3,i4] = refractoryDuration[i1,i2,i3,i4] + firingCounter[i1,i2,i3,i4] += 1 + vt[i1,i2,i3,i4] = vRest[i1,i2,i3,i4] else - zt[i] = 0 + zt[i1,i2,i3,i4] = 0 end # compute phi, there is a difference from on formula - phi[i] = (gammaPd[i] / vth[i]) * max(0, 1 - ((vt[i] - vth[i]) / vth[i])) + phi[i1,i2,i3,i4] = (gammaPd[i1,i2,i3,i4] / vth[i1,i2,i3,i4]) * max(0, 1 - ((vt[i1,i2,i3,i4] - vth[i1,i2,i3,i4]) / vth[i1,i2,i3,i4])) # compute epsilonRec - epsilonRec[i] = (alpha[i] * epsilonRec[i]) + zit[i] + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + zit[i1,i2,i3,i4] end end return nothing