diff --git a/src/forward.jl b/src/forward.jl index 4790036..e11e36c 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -40,48 +40,53 @@ function (kfn::kfn_1)(input::AbstractArray) reshape(kfn.alif_zt, (size(input, 1), :, 1, size(input, 3))), dims=2) kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3))) - # project 3D kfn zit into 4D lif zit - i1, i2, i3, i4 = size(kfn.lif_zit) - kfn.lif_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.lif_arrayProjection4d - - lifForward( kfn.lif_zit, - kfn.lif_wRec, - kfn.lif_vt, - kfn.lif_vth, - kfn.lif_vRest, - kfn.lif_zt4d, - kfn.lif_alpha, - kfn.lif_phi, - kfn.lif_epsilonRec, - kfn.lif_refractoryCounter, - kfn.lif_refractoryDuration, - kfn.lif_gammaPd, - kfn.lif_firingCounter, - kfn.lif_recSignal,) - - # project 3D kfn zit into 4D alif zit - i1, i2, i3, i4 = size(kfn.alif_zit) - kfn.alif_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.alif_arrayProjection4d - - alifForward(kfn.alif_zit, - kfn.alif_wRec, - kfn.alif_vt, - kfn.alif_vth, - kfn.alif_vRest, - kfn.alif_zt4d, - kfn.alif_alpha, - kfn.alif_phi, - kfn.alif_epsilonRec, - kfn.alif_refractoryCounter, - kfn.alif_refractoryDuration, - kfn.alif_gammaPd, - kfn.alif_firingCounter, - kfn.alif_recSignal, - kfn.alif_epsilonRecA, - kfn.alif_a, - kfn.alif_avth, - kfn.alif_beta, - kfn.alif_rho,) + @sync begin + @async begin + # project 3D kfn zit into 4D lif zit + i1, i2, i3, i4 = size(kfn.lif_zit) + kfn.lif_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.lif_arrayProjection4d + + lifForward( kfn.lif_zit, + kfn.lif_wRec, + kfn.lif_vt, + kfn.lif_vth, + kfn.lif_vRest, + kfn.lif_zt4d, + kfn.lif_alpha, + kfn.lif_phi, + kfn.lif_epsilonRec, + kfn.lif_refractoryCounter, + kfn.lif_refractoryDuration, + kfn.lif_gammaPd, + kfn.lif_firingCounter, + kfn.lif_recSignal,) + end + @async begin + # project 3D kfn zit into 4D alif zit + i1, i2, i3, i4 = size(kfn.alif_zit) + kfn.alif_zit .= reshape(kfn.zit, (i1, i2, 1, i4)) .* kfn.alif_arrayProjection4d + + alifForward(kfn.alif_zit, + kfn.alif_wRec, + kfn.alif_vt, + kfn.alif_vth, + kfn.alif_vRest, + kfn.alif_zt4d, + kfn.alif_alpha, + kfn.alif_phi, + kfn.alif_epsilonRec, + kfn.alif_refractoryCounter, + kfn.alif_refractoryDuration, + kfn.alif_gammaPd, + kfn.alif_firingCounter, + kfn.alif_recSignal, + kfn.alif_epsilonRecA, + kfn.alif_a, + kfn.alif_avth, + kfn.alif_beta, + kfn.alif_rho,) + end + end # reduce lif_zt4d and alif_zt4d into lif_zt, alif_zt (4d -> 1d) kfn.lif_zt .= reduce(max, kfn.lif_zt4d, dims=(1,2))