diff --git a/src/forward.jl b/src/forward.jl index bf89c8b..bd08341 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -103,7 +103,7 @@ function (kfn::kfn_1)(input::AbstractArray) kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3))) # read out - onForward( kfn.zit, + onForward( kfn.zit, kfn.on_zit, kfn.on_wOut, kfn.on_vt0, @@ -117,36 +117,45 @@ function (kfn::kfn_1)(input::AbstractArray) kfn.on_refractoryCounter, kfn.on_refractoryDuration, kfn.on_gammaPd, - kfn.on_firingCounter) + kfn.on_firingCounter, + kfn.on_arrayProjection3DTo4D, + kfn.on_recSignal, + kfn.on_decayed_vt0, + kfn.on_decayed_epsilonRec, + kfn.on_vt1_diff_vth, + kfn.on_vt1_diff_vth_div_vth, + kfn.on_gammaPd_div_vth, + kfn.on_phiActivation) return reshape(kfn.on_zt1, (size(input, 1), :)), kfn.zit end -function lifForward(kfn_zit, - zit, - wRec, - vt0, - vt1, - vth, - vRest, - zt1, - alpha, - phi, - epsilonRec, - refractoryCounter, - refractoryDuration, - gammaPd, - firingCounter, - arrayProjection3DTo4D, - recSignal, - decayed_vt0, - decayed_epsilonRec, - vt1_diff_vth, - vt1_diff_vth_div_vth, - gammaPd_div_vth, - phiActivation) +function lifForward(kfn_zit::Array{T}, + zit::Array{T}, + wRec::Array{T}, + vt0::Array{T}, + vt1::Array{T}, + vth::Array{T}, + vRest::Array{T}, + zt1::Array{T}, + alpha::Array{T}, + phi::Array{T}, + epsilonRec::Array{T}, + refractoryCounter::Array{T}, + refractoryDuration::Array{T}, + gammaPd::Array{T}, + firingCounter::Array{T}, + arrayProjection3DTo4D::Array{T}, + recSignal::Array{T}, + decayed_vt0::Array{T}, + decayed_epsilonRec::Array{T}, + vt1_diff_vth::Array{T}, + vt1_diff_vth_div_vth::Array{T}, + gammaPd_div_vth::Array{T}, + phiActivation::Array{T}, + ) where T<:Number # project 3D kfn zit into 4D lif zit zit .= reshape(kfn_zit, @@ -190,40 +199,41 @@ function lifForward(kfn_zit, end end -function alifForward(kfn_zit, - zit, - wRec, - vt0, - vt1, - vth, - vRest, - zt1, - alpha, - phi, - epsilonRec, - refractoryCounter, - refractoryDuration, - gammaPd, - firingCounter, - arrayProjection3DTo4D, - recSignal, - decayed_vt0, - decayed_epsilonRec, - vt1_diff_vth, - vt1_diff_vth_div_vth, - gammaPd_div_vth, - phiActivation, +function alifForward(kfn_zit::Array{T}, + zit::Array{T}, + wRec::Array{T}, + vt0::Array{T}, + vt1::Array{T}, + vth::Array{T}, + vRest::Array{T}, + zt1::Array{T}, + alpha::Array{T}, + phi::Array{T}, + epsilonRec::Array{T}, + refractoryCounter::Array{T}, + refractoryDuration::Array{T}, + gammaPd::Array{T}, + firingCounter::Array{T}, + arrayProjection3DTo4D::Array{T}, + recSignal::Array{T}, + decayed_vt0::Array{T}, + decayed_epsilonRec::Array{T}, + vt1_diff_vth::Array{T}, + vt1_diff_vth_div_vth::Array{T}, + gammaPd_div_vth::Array{T}, + phiActivation::Array{T}, - epsilonRecA, - avth, - a, - beta, - rho, - phi_x_epsilonRec, - phi_x_beta, - rho_diff_phi_x_beta, - rho_div_phi_x_beta_x_epsilonRecA, - beta_x_a) + epsilonRecA::Array{T}, + avth::Array{T}, + a::Array{T}, + beta::Array{T}, + rho::Array{T}, + phi_x_epsilonRec::Array{T}, + phi_x_beta::Array{T}, + rho_diff_phi_x_beta::Array{T}, + rho_div_phi_x_beta_x_epsilonRecA::Array{T}, + beta_x_a::Array{T}, + ) where T<:Number # project 3D kfn zit into 4D lif zit @@ -295,125 +305,69 @@ function alifForward(kfn_zit, end end -# function alifForward(kfn_zit, -# zit, -# wRec, -# vt0, -# vt1, -# vth, -# avth, -# vRest, -# zt1, -# alpha, -# phi, -# epsilonRec, -# epsilonRecA, -# refractoryCounter, -# refractoryDuration, -# a, -# beta, -# rho, -# gammaPd, -# firingCounter) -# d1, d2, d3, d4 = size(wRec) -# zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wRec)...) # project zit into zit +function onForward(kfn_zit::Array{T}, + zit::Array{T}, + wOut::Array{T}, + vt0::Array{T}, + vt1::Array{T}, + vth::Array{T}, + vRest::Array{T}, + zt1::Array{T}, + alpha::Array{T}, + phi::Array{T}, + epsilonRec::Array{T}, + refractoryCounter::Array{T}, + refractoryDuration::Array{T}, + gammaPd::Array{T}, + firingCounter::Array{T}, + arrayProjection3DTo4D::Array{T}, + recSignal::Array{T}, + decayed_vt0::Array{T}, + decayed_epsilonRec::Array{T}, + vt1_diff_vth::Array{T}, + vt1_diff_vth_div_vth::Array{T}, + gammaPd_div_vth::Array{T}, + phiActivation::Array{T}, + ) where T<:Number -# for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch -# if view(refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active -# view(refractoryCounter, :, :, i, j)[1] -= 1 -# view(zt1, :, :, i, j)[1] = 0 -# view(vt1, :, :, i, j)[1] = view(alpha, :, :, i, j)[1] * -# view(vt0, :, :, i, j)[1] -# view(phi, :, :, i, j)[1] = 0.0 -# view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .* -# view(epsilonRec, :, :, i, j) -# view(a, :, :, i, j)[1] = -# (view(rho, :, :, i, j)[1] * view(a, :, :, i, j)[1]) + 0 -# else # refractory period is inactive -# view(vt1, :, :, i, j)[1] = -# (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) + -# sum(view(zit, :, :, i, j) .* view(wRec, :, :, i, j)) -# view(avth, :, :, i, j)[1] = view(vth, :, :, i, j)[1] + -# (view(beta, :, :, i, j)[1] * view(a, :, :, i, j)[1]) -# if view(vt1, :, :, i, j)[1] > view(avth, :, :, i, j)[1] -# view(zt1, :, :, i, j)[1] = 1 -# view(refractoryCounter, :, :, i, j)[1] = -# view(refractoryDuration, :, :, i, j)[1] -# view(firingCounter, :, :, i, j)[1] += 1 -# view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1] -# view(a, :, :, i, j)[1] = (view(rho, :, :, i, j)[1] * -# view(a, :, :, i, j)[1]) + 1 -# else -# view(zt1, :, :, i, j)[1] = 0 -# view(a, :, :, i, j)[1] = -# (view(rho, :, :, i, j)[1] * view(a, :, :, i, j)[1]) + 0 -# end - -# # there is a difference from alif formula -# view(phi, :, :, i, j)[1] = -# (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) * -# max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(avth, :, :, i, j)[1]) / -# view(vth, :, :, i, j)[1])) -# view(epsilonRec, :, :, i, j) .= -# (view(alpha, :, :, i, j) .* view(epsilonRec, :, :, i, j)) + -# view(zit, :, :, i, j) -# view(epsilonRecA, :, :, i, j) .= -# (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) + -# ((view(rho, :, :, i, j)[1] - -# (view(phi, :, :, i, j)[1] * view(beta, :, :, i, j)[1])) .* -# view(epsilonRecA, :, :, i, j)) -# end -# end -# end - -function onForward(kfn_zit, - zit, - wOut, - vt0, - vt1, - vth, - vRest, - zt1, - alpha, - phi, - epsilonRec, - refractoryCounter, - refractoryDuration, - gammaPd, - firingCounter) - d1, d2, d3, d4 = size(wOut) - zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wOut)...) # project zit into zit + # project 3D kfn zit into 4D lif zit + zit .= reshape(kfn_zit, + (size(wOut, 1), size(wOut, 2), 1, size(wOut, 4))) .* arrayProjection3DTo4D - for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch - if view(refractoryCounter, :, :, i, j)[1] > 0 # neuron is inactive (in refractory period) - view(refractoryCounter, :, :, i, j)[1] -= 1 - view(zt1, :, :, i, j)[1] = 0 - view(vt1, :, :, i, j)[1] = - view(alpha, :, :, i, j)[1] * view(vt0, :, :, i, j)[1] - view(phi, :, :, i, j)[1] = 0.0 - view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .* - view(epsilonRec, :, :, i, j) - else # neuron is active - view(vt1, :, :, i, j)[1] = - (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) + - sum(view(zit, :, :, i, j) .* view(wOut, :, :, i, j)) - if view(vt1, :, :, i, j)[1] > view(vth, :, :, i, j)[1] - view(zt1, :, :, i, j)[1] = 1 - view(refractoryCounter, :, :, i, j)[1] = - view(refractoryDuration, :, :, i, j)[1] - view(firingCounter, :, :, i, j)[1] += 1 - view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1] + for j in 1:size(wOut, 4), i in 1:size(wOut, 3) # compute along neurons axis of every batch + if sum(@view(refractoryCounter[:,:,i,j])) > 0 # refractory period is active + @. @views refractoryCounter[:,:,i,j] -= 1 + @. @views zt1[:,:,i,j] = 0 + @. @views vt1[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] + @. @views phi[:,:,i,j] = 0 + + # compute epsilonRec + @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] + @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + else # refractory period is inactive + @. @views recSignal[:,:,i,j] = zit[:,:,i,j] * wOut[:,:,i,j] + @. @views decayed_vt0[:,:,i,j] = alpha[:,:,i,j] * vt0[:,:,i,j] + @view(vt1[:,:,i,j]) .= @view(decayed_vt0[:,:,i,j]) .+ sum(@view(recSignal[:,:,i,j])) + + if sum(@view(vt1[:,:,i,j])) > sum(@view(vth[:,:,i,j])) + @. @views zt1[:,:,i,j] = 1 + @. @views refractoryCounter[:,:,i,j] = refractoryDuration[:,:,i,j] + @. @views firingCounter[:,:,i,j] += 1 + @. @views vt1[:,:,i,j] = vRest[:,:,i,j] else - view(zt1, :, :, i, j)[1] = 0 + @. @views zt1[:,:,i,j] = 0 end - # there is a difference from alif formula - view(phi, :, :, i, j)[1] = - (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) * - max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(vth, :, :, i, j)[1]) / - view(vth, :, :, i, j)[1])) - view(epsilonRec, :, :, i, j) .= - (view(alpha, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) + - view(zit, :, :, i, j) + + # compute phi, there is a difference from alif formula + @. @views gammaPd_div_vth[:,:,i,j] = gammaPd[:,:,i,j] / vth[:,:,i,j] + @. @views vt1_diff_vth[:,:,i,j] = vt1[:,:,i,j] - vth[:,:,i,j] + @. @views vt1_diff_vth_div_vth[:,:,i,j] = vt1_diff_vth[:,:,i,j] / vth[:,:,i,j] + @view(phiActivation[:,:,i,j]) .= max(0, 1 - sum(@view(vt1_diff_vth_div_vth[:,:,i,j]))) + @. @views phi[:,:,i,j] = gammaPd_div_vth[:,:,i,j] * phiActivation[:,:,i,j] + + # compute epsilonRec + @. @views decayed_epsilonRec[:,:,i,j] = alpha[:,:,i,j] * epsilonRec[:,:,i,j] + @. @views epsilonRec[:,:,i,j] = decayed_epsilonRec[:,:,i,j] + zit[:,:,i,j] end end end diff --git a/src/type.jl b/src/type.jl index 161add4..0d14e07 100644 --- a/src/type.jl +++ b/src/type.jl @@ -116,9 +116,10 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn # Output Neurons # # ---------------------------------------------------------------------------- # # output neuron is based on LIF - on_zit::Union{AbstractArray, Nothing} = nothing + on_zit::Union{AbstractArray, Nothing} = nothing - on_wOut::Union{AbstractArray, Nothing} = nothing # same as lif_wRec + # main variables according to papers + on_wOut::Union{AbstractArray, Nothing} = nothing # wOut is wRec, just use the name from paper on_vt0::Union{AbstractArray, Nothing} = nothing on_vt1::Union{AbstractArray, Nothing} = nothing on_vth::Union{AbstractArray, Nothing} = nothing @@ -135,14 +136,19 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn on_eRec::Union{AbstractArray, Nothing} = nothing on_eta::Union{AbstractArray, Nothing} = nothing on_gammaPd::Union{AbstractArray, Nothing} = nothing - on_wOutChange::Union{AbstractArray, Nothing} = nothing - on_b::Union{AbstractArray, Nothing} = nothing - on_bChange::Union{AbstractArray, Nothing} = nothing on_firingCounter::Union{AbstractArray, Nothing} = nothing - on_arraySize::Union{AbstractArray, Nothing} = nothing + + # pre-allocation array on_arrayProjection3DTo4D::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d + on_recSignal::Union{AbstractArray, Nothing} = nothing + on_decayed_vt0::Union{AbstractArray, Nothing} = nothing + on_decayed_epsilonRec::Union{AbstractArray, Nothing} = nothing + on_vt1_diff_vth::Union{AbstractArray, Nothing} = nothing + on_vt1_diff_vth_div_vth::Union{AbstractArray, Nothing} = nothing + on_gammaPd_div_vth::Union{AbstractArray, Nothing} = nothing + on_phiActivation::Union{AbstractArray, Nothing} = nothing end # outer constructor @@ -214,7 +220,6 @@ function kfn_1(params::Dict; device=cpu) kfn.lif_gammaPd_div_vth = similar(kfn.lif_vt0) .= 0 |> device kfn.lif_phiActivation = similar(kfn.lif_vt0) .= 0 |> device - # ---------------------------------------------------------------------------- # # ALIF config # # ---------------------------------------------------------------------------- # @@ -279,34 +284,10 @@ function kfn_1(params::Dict; device=cpu) # output config # # ---------------------------------------------------------------------------- # n = kfn.params[:outputPort][:numbers][1] * kfn.params[:outputPort][:numbers][2] - kfn.on_zit = zeros(row, col, n, batch) |> device - kfn.on_vt0 = zeros(1, 1, n, batch) |> device - kfn.on_vt1 = zeros(1, 1, n, batch) |> device - kfn.on_vth = ones(1, 1, n, batch) |> device - kfn.on_vRest = zeros(1, 1, n, batch) |> device - # kfn.on_zt0 = zeros(1, 1, n, batch) |> device - kfn.on_zt1 = zeros(1, 1, n, batch) |> device - kfn.on_refractoryCounter = zeros(1, 1, n, batch) |> device - kfn.on_refractoryDuration = ones(1, 1, n, batch) .* 0 |> device - kfn.on_delta = 1.0 - kfn.on_tau_m = 20.0 - kfn.on_alpha = ones(1, 1, n, batch) .* (exp(-kfn.on_delta / kfn.on_tau_m)) |> device - kfn.on_phi = zeros(1, 1, n, batch) |> device - kfn.on_epsilonRec = zeros(row, col, n, batch) |> device - # kfn.on_eRec = zeros(row, col, n, batch) - kfn.on_eta = zeros(1, 1, n, batch) |> device - kfn.on_gammaPd = zeros(1, 1, n, batch) .* 0.3 |> device - kfn.on_wOutChange = zeros(row, col, n, batch) |> device - # kfn.on_b = randn(1, 1, n, batch) |> device - # kfn.on_bChange = randn(1, 1, n, batch) |> device - - kfn.on_firingCounter = zeros(1, 1, n, batch) |> device - kfn.on_arraySize = [row, col, n, batch] |> device - kfn.on_arrayProjection3DTo4D = ones(row, col, n, batch) |> device # subscription w = zeros(row, col, n) - synapticConnectionPercent = kfn.params[:outputPort][:params][:synapticConnectionPercent] + synapticConnectionPercent = kfn.params[:computeNeuron][:lif][:params][:synapticConnectionPercent] synapticConnection = Int(floor(row*col * synapticConnectionPercent/100)) for slice in eachslice(w, dims=3) pool = shuffle!([1:row*col...])[1:synapticConnection] @@ -314,8 +295,75 @@ function kfn_1(params::Dict; device=cpu) slice[i] = randn()/10 # assign weight to synaptic connection end end - # project 3D w into 4D kfn.on_wOut - kfn.on_wOut = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device + # project 3D w into 4D kfn.lif_wOut (row, col, n, batch) + kfn.on_wOut = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device + kfn.on_zit = similar(kfn.on_wOut) .= 0 |> device + kfn.on_vt0 = zeros(1, 1, n, batch) |> device + kfn.on_vt1 = similar(kfn.on_vt0) .= 0 |> device + kfn.on_vth = similar(kfn.on_vt0) .= 1 |> device + kfn.on_vRest = similar(kfn.on_vt0) .= 0 |> device + kfn.on_zt0 = similar(kfn.on_vt0) .= 0 |> device + kfn.on_zt1 = similar(kfn.on_vt0) .= 0 |> device + kfn.on_refractoryCounter = similar(kfn.on_vt0) .= 0 |> device + kfn.on_refractoryDuration = similar(kfn.on_vt0) .= 0 |> device + kfn.on_delta = 1.0 + kfn.on_tau_m = 20.0 + kfn.on_alpha = similar(kfn.on_vt0) .= (exp(-kfn.on_delta / kfn.on_tau_m)) |> device + kfn.on_phi = similar(kfn.on_vt0) .= 0 |> device + kfn.on_epsilonRec = similar(kfn.on_wOut) .= 0 |> device + kfn.on_eRec = similar(kfn.on_wOut) .= 0 |> device + kfn.on_eta = similar(kfn.on_vt0) .= 0 |> device + kfn.on_gammaPd = similar(kfn.on_vt0) .= 0.3 |> device + kfn.on_wOutChange = similar(kfn.on_wOut) .= 0 |> device + + kfn.on_firingCounter = similar(kfn.on_vt0) .= 0 |> device + + kfn.on_arrayProjection3DTo4D = similar(kfn.on_wOut) .= 1 |> device + kfn.on_recSignal = similar(kfn.on_wOut) .= 0 |> device + kfn.on_decayed_vt0 = similar(kfn.on_vt0) .= 0 |> device + kfn.on_decayed_epsilonRec = similar(kfn.on_wOut) .= 0 |> device + kfn.on_vt1_diff_vth = similar(kfn.on_vt0) .= 0 |> device + kfn.on_vt1_diff_vth_div_vth = similar(kfn.on_vt0) .= 0 |> device + kfn.on_gammaPd_div_vth = similar(kfn.on_vt0) .= 0 |> device + kfn.on_phiActivation = similar(kfn.on_vt0) .= 0 |> device + + # kfn.on_zit = zeros(row, col, n, batch) |> device + # kfn.on_vt0 = zeros(1, 1, n, batch) |> device + # kfn.on_vt1 = zeros(1, 1, n, batch) |> device + # kfn.on_vth = ones(1, 1, n, batch) |> device + # kfn.on_vRest = zeros(1, 1, n, batch) |> device + # # kfn.on_zt0 = zeros(1, 1, n, batch) |> device + # kfn.on_zt1 = zeros(1, 1, n, batch) |> device + # kfn.on_refractoryCounter = zeros(1, 1, n, batch) |> device + # kfn.on_refractoryDuration = ones(1, 1, n, batch) .* 0 |> device + # kfn.on_delta = 1.0 + # kfn.on_tau_m = 20.0 + # kfn.on_alpha = ones(1, 1, n, batch) .* (exp(-kfn.on_delta / kfn.on_tau_m)) |> device + # kfn.on_phi = zeros(1, 1, n, batch) |> device + # kfn.on_epsilonRec = zeros(row, col, n, batch) |> device + # # kfn.on_eRec = zeros(row, col, n, batch) + # kfn.on_eta = zeros(1, 1, n, batch) |> device + # kfn.on_gammaPd = zeros(1, 1, n, batch) .* 0.3 |> device + # kfn.on_wOutChange = zeros(row, col, n, batch) |> device + # # kfn.on_b = randn(1, 1, n, batch) |> device + # # kfn.on_bChange = randn(1, 1, n, batch) |> device + + # kfn.on_firingCounter = zeros(1, 1, n, batch) |> device + # kfn.on_arraySize = [row, col, n, batch] |> device + # kfn.on_arrayProjection3DTo4D = ones(row, col, n, batch) |> device + + # # subscription + # w = zeros(row, col, n) + # synapticConnectionPercent = kfn.params[:outputPort][:params][:synapticConnectionPercent] + # synapticConnection = Int(floor(row*col * synapticConnectionPercent/100)) + # for slice in eachslice(w, dims=3) + # pool = shuffle!([1:row*col...])[1:synapticConnection] + # for i in pool + # slice[i] = randn()/10 # assign weight to synaptic connection + # end + # end + # # project 3D w into 4D kfn.on_wOut + # kfn.on_wOut = reshape(w, (row, col, n, 1)) .* ones(row, col, n, batch) |> device