From d4bfe537b9bd17cacececac6c41a274b88b3b312 Mon Sep 17 00:00:00 2001 From: ton Date: Sat, 26 Aug 2023 15:43:33 +0700 Subject: [PATCH] fix device error --- src/forward.jl | 414 ++++++++++++++++++++++++------------------------- src/learn.jl | 4 +- 2 files changed, 209 insertions(+), 209 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index eedae6c..a62c636 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -612,238 +612,238 @@ function onForward( zit, return nothing end -function lifForward(kfn_zit::Array{T}, - zit::Array{T}, - wRec::Array{T}, - vt0::Array{T}, - vt1::Array{T}, - vth::Array{T}, - vRest::Array{T}, - zt1::Array{T}, - alpha::Array{T}, - phi::Array{T}, - epsilonRec::Array{T}, - refractoryCounter::Array{T}, - refractoryDuration::Array{T}, - gammaPd::Array{T}, - firingCounter::Array{T}, - arrayProjection4d::Array{T}, - recSignal::Array{T}, - decayed_vt0::Array{T}, - decayed_epsilonRec::Array{T}, - vt1_diff_vth::Array{T}, - vt1_diff_vth_div_vth::Array{T}, - gammaPd_div_vth::Array{T}, - phiActivation::Array{T}, - ) where T<:Number +# function lifForward(kfn_zit::Array{T}, +# zit::Array{T}, +# wRec::Array{T}, +# vt0::Array{T}, +# vt1::Array{T}, +# vth::Array{T}, +# vRest::Array{T}, +# zt1::Array{T}, +# alpha::Array{T}, +# phi::Array{T}, +# epsilonRec::Array{T}, +# refractoryCounter::Array{T}, +# refractoryDuration::Array{T}, +# gammaPd::Array{T}, +# firingCounter::Array{T}, +# arrayProjection4d::Array{T}, +# recSignal::Array{T}, +# decayed_vt0::Array{T}, +# decayed_epsilonRec::Array{T}, +# vt1_diff_vth::Array{T}, +# vt1_diff_vth_div_vth::Array{T}, +# gammaPd_div_vth::Array{T}, +# phiActivation::Array{T}, +# ) where T<:Number - # project 3D kfn zit into 4D lif zit - i1, i2, i3, i4 = size(alif_wRec) - lif_zit .= reshape(kfn_zit, (i1, i2, 1, i4)) .* lif_arrayProjection4d +# # project 3D kfn zit into 4D lif zit +# i1, i2, i3, i4 = size(alif_wRec) +# lif_zit .= reshape(kfn_zit, (i1, i2, 1, i4)) .* lif_arrayProjection4d - for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch - if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active - @. @views refractoryCounter[:,:,i,j] -= 1 - @. @views zt1[:,:,i,j] = 0 - @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] - @. @views phi[:,:,i,j] = 0 +# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch +# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active +# @. @views refractoryCounter[:,:,i,j] -= 1 +# @. @views zt1[:,:,i,j] = 0 +# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] +# @. @views phi[:,:,i,j] = 0 - # compute epsilonRec - @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] - else # refractory period is inactive - @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j] - @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] - @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) +# # compute epsilonRec +# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] +# else # refractory period is inactive +# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j] +# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] +# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) - if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j])) - @. @views zt1[:,:,i,j] = 1 - @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] - @. @views firingCounter[:,:,i,j] += 1 - @. @views vt1[:,:,i,j] = vRest[:,:,i,j] - else - @. @views zt1[:,:,i,j] = 0 - end +# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j])) +# @. @views zt1[:,:,i,j] = 1 +# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] +# @. @views firingCounter[:,:,i,j] += 1 +# @. @views vt1[:,:,i,j] = vRest[:,:,i,j] +# else +# @. @views zt1[:,:,i,j] = 0 +# end - # compute phi, there is a difference from alif formula - @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] - @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] - @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] - @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) - @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] +# # compute phi, there is a difference from alif formula +# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] +# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] +# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] +# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) +# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] - # compute epsilonRec - @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] - end - end -end +# # compute epsilonRec +# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] +# end +# end +# end -function alifForward(zit::Array{T}, - wRec::Array{T}, - vt0::Array{T}, - vt1::Array{T}, - vth::Array{T}, - vRest::Array{T}, - zt1::Array{T}, - alpha::Array{T}, - phi::Array{T}, - epsilonRec::Array{T}, - refractoryCounter::Array{T}, - refractoryDuration::Array{T}, - gammaPd::Array{T}, - firingCounter::Array{T}, - recSignal::Array{T}, - decayed_vt0::Array{T}, - decayed_epsilonRec::Array{T}, - vt1_diff_vth::Array{T}, - vt1_diff_vth_div_vth::Array{T}, - gammaPd_div_vth::Array{T}, - phiActivation::Array{T}, +# function alifForward(zit::Array{T}, +# wRec::Array{T}, +# vt0::Array{T}, +# vt1::Array{T}, +# vth::Array{T}, +# vRest::Array{T}, +# zt1::Array{T}, +# alpha::Array{T}, +# phi::Array{T}, +# epsilonRec::Array{T}, +# refractoryCounter::Array{T}, +# refractoryDuration::Array{T}, +# gammaPd::Array{T}, +# firingCounter::Array{T}, +# recSignal::Array{T}, +# decayed_vt0::Array{T}, +# decayed_epsilonRec::Array{T}, +# vt1_diff_vth::Array{T}, +# vt1_diff_vth_div_vth::Array{T}, +# gammaPd_div_vth::Array{T}, +# phiActivation::Array{T}, - epsilonRecA::Array{T}, - avth::Array{T}, - a::Array{T}, - beta::Array{T}, - rho::Array{T}, - phi_x_epsilonRec::Array{T}, - phi_x_beta::Array{T}, - rho_diff_phi_x_beta::Array{T}, - rho_div_phi_x_beta_x_epsilonRecA::Array{T}, - beta_x_a::Array{T}, - ) where T<:Number +# epsilonRecA::Array{T}, +# avth::Array{T}, +# a::Array{T}, +# beta::Array{T}, +# rho::Array{T}, +# phi_x_epsilonRec::Array{T}, +# phi_x_beta::Array{T}, +# rho_diff_phi_x_beta::Array{T}, +# rho_div_phi_x_beta_x_epsilonRecA::Array{T}, +# beta_x_a::Array{T}, +# ) where T<:Number - for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch - if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active - @. @views refractoryCounter[:,:,i,j] -= 1 - @. @views zt1[:,:,i,j] = 0 - @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] - @. @views phi[:,:,i,j] = 0 - @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j] +# for j in 1:size(wRec, 4), i in 1:size(wRec, 3) # compute along neurons axis of every batch +# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active +# @. @views refractoryCounter[:,:,i,j] -= 1 +# @. @views zt1[:,:,i,j] = 0 +# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] +# @. @views phi[:,:,i,j] = 0 +# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j] - # compute epsilonRec - @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] +# # compute epsilonRec +# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] - # compute epsilonRecA - @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j] - @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j] - @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j] - @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] +# # compute epsilonRecA +# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j] +# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j] +# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j] +# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] - # compute avth - @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j] - @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j] +# # compute avth +# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j] +# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j] - else # refractory period is inactive - @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j] - @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] - @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) +# else # refractory period is inactive +# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wRec[:,:,i,j] +# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] +# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) - # compute avth - @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j] - @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j] +# # compute avth +# @. @views beta_x_a[:,:,i,j] = beta[:,:,i,j] * a[:,:,i,j] +# @. @views avth[:,:,i,j] = vth[:,:,i,j] + beta_x_a[:,:,i,j] - if sum(@view(vt1[:,:,i,j])) > sum(@view(avth[:,:,i,j])) - @. @views zt1[:,:,i,j] = 1 - @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] - @. @views firingCounter[:,:,i,j] += 1 - @. @views vt1[:,:,i,j] = vRest[:,:,i,j] - @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j] - @. @views a[:,:,i,j] = a[:,:,i,j] += 1 - else - @. @views zt1[:,:,i,j] = 0 - @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j] - end +# if sum(@view(vt1[:,:,i,j])) > sum(@view(avth[:,:,i,j])) +# @. @views zt1[:,:,i,j] = 1 +# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] +# @. @views firingCounter[:,:,i,j] += 1 +# @. @views vt1[:,:,i,j] = vRest[:,:,i,j] +# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j] +# @. @views a[:,:,i,j] = a[:,:,i,j] += 1 +# else +# @. @views zt1[:,:,i,j] = 0 +# @. @views a[:,:,i,j] = rho[:,:,i,j] * a[:,:,i,j] +# end - # compute phi, there is a difference from alif formula - @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] - @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] - @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] - @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) - @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] +# # compute phi, there is a difference from alif formula +# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] +# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] +# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] +# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) +# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] - # compute epsilonRec - @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] +# # compute epsilonRec +# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] - # compute epsilonRecA - @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j] - @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j] - @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j] - @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] - end - end -end +# # compute epsilonRecA +# @. @views phi_x_epsilonRec[:,:,i,j] = phi[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views phi_x_beta[:,:,i,j] = phi[:,:,i,j] * beta[:,:,i,j] +# @. @views rho_diff_phi_x_beta[:,:,i,j] = rho[:,:,i,j] - phi_x_beta[:,:,i,j] +# @. @views rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] = rho_diff_phi_x_beta[:,:,i,j] * epsilonRecA[:,:,i,j] +# @. @views epsilonRecA[:,:,i,j] = phi_x_epsilonRec[:,:,i,j] + rho_div_phi_x_beta_x_epsilonRecA[:,:,i,j] +# end +# end +# end -function onForward(kfn_zit::Array{T}, - zit::Array{T}, - wOut::Array{T}, - vt0::Array{T}, - vt1::Array{T}, - vth::Array{T}, - vRest::Array{T}, - zt1::Array{T}, - alpha::Array{T}, - phi::Array{T}, - epsilonRec::Array{T}, - refractoryCounter::Array{T}, - refractoryDuration::Array{T}, - gammaPd::Array{T}, - firingCounter::Array{T}, - arrayProjection4d::Array{T}, - recSignal::Array{T}, - decayed_vt0::Array{T}, - decayed_epsilonRec::Array{T}, - vt1_diff_vth::Array{T}, - vt1_diff_vth_div_vth::Array{T}, - gammaPd_div_vth::Array{T}, - phiActivation::Array{T}, - ) where T<:Number +# function onForward(kfn_zit::Array{T}, +# zit::Array{T}, +# wOut::Array{T}, +# vt0::Array{T}, +# vt1::Array{T}, +# vth::Array{T}, +# vRest::Array{T}, +# zt1::Array{T}, +# alpha::Array{T}, +# phi::Array{T}, +# epsilonRec::Array{T}, +# refractoryCounter::Array{T}, +# refractoryDuration::Array{T}, +# gammaPd::Array{T}, +# firingCounter::Array{T}, +# arrayProjection4d::Array{T}, +# recSignal::Array{T}, +# decayed_vt0::Array{T}, +# decayed_epsilonRec::Array{T}, +# vt1_diff_vth::Array{T}, +# vt1_diff_vth_div_vth::Array{T}, +# gammaPd_div_vth::Array{T}, +# phiActivation::Array{T}, +# ) where T<:Number - # project 3D kfn zit into 4D lif zit - zit .= reshape(kfn_zit, - (size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection4d +# # project 3D kfn zit into 4D lif zit +# zit .= reshape(kfn_zit, +# (size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection4d - for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch - if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active - @. @views refractoryCounter[:,:,i,j] -= 1 - @. @views zt1[:,:,i,j] = 0 - @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] - @. @views phi[:,:,i,j] = 0 +# for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch +# if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active +# @. @views refractoryCounter[:,:,i,j] -= 1 +# @. @views zt1[:,:,i,j] = 0 +# @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] +# @. @views phi[:,:,i,j] = 0 - # compute epsilonRec - @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] - else # refractory period is inactive - @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wOut[:,:,i,j] - @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] - @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) +# # compute epsilonRec +# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] +# else # refractory period is inactive +# @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wOut[:,:,i,j] +# @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] +# @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) - if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j])) - @. @views zt1[:,:,i,j] = 1 - @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] - @. @views firingCounter[:,:,i,j] += 1 - @. @views vt1[:,:,i,j] = vRest[:,:,i,j] - else - @. @views zt1[:,:,i,j] = 0 - end +# if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j])) +# @. @views zt1[:,:,i,j] = 1 +# @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] +# @. @views firingCounter[:,:,i,j] += 1 +# @. @views vt1[:,:,i,j] = vRest[:,:,i,j] +# else +# @. @views zt1[:,:,i,j] = 0 +# end - # compute phi, there is a difference from alif formula - @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] - @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] - @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] - @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) - @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] +# # compute phi, there is a difference from alif formula +# @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] +# @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] +# @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] +# @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) +# @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] - # compute epsilonRec - @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] - @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] - end - end -end +# # compute epsilonRec +# @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] +# @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] +# end +# end +# end diff --git a/src/learn.jl b/src/learn.jl index b59bee9..8ce4ffe 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -275,7 +275,7 @@ function learn!(kfn::kfn_1, device=cpu) kfn.lif_synapticConnectionNumber, kfn.zitCumulative, device) - + println("kfn.lif_wRec $(typeof(kfn.lif_wRec))") # alif learn kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapticInactivityCounter = alifLearn(kfn.alif_wRec, @@ -286,7 +286,7 @@ function learn!(kfn::kfn_1, device=cpu) kfn.alif_synapticConnectionNumber, kfn.zitCumulative, device) - error("DEBUG -> kfn learn! $(Dates.now())") + # on learn onLearn!(kfn.on_wOut, kfn.on_wOutChange,