From 1e08a4d750a7437dc56c70573a5f9f79e493c3fa Mon Sep 17 00:00:00 2001 From: ton Date: Mon, 24 Jul 2023 14:21:28 +0700 Subject: [PATCH] kfn forward() --- src/forward.jl | 328 +++++++++++++++++++++++++++++-------------------- src/type.jl | 80 ++++++++++-- 2 files changed, 264 insertions(+), 144 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index 2e5d2bc..162f95b 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -7,7 +7,9 @@ using ..type, ..snnUtil #------------------------------------------------------------------------------------------------100 -# kfn forward +""" kfn forward + input (row, col, batch) +""" function (kfn::kfn_1)(input::AbstractArray) kfn.timeStep .+= 1 @@ -17,8 +19,9 @@ function (kfn::kfn_1)(input::AbstractArray) end println(">>> input ", size(input)) + d1, d2, d3 = size(input) println(">>> zit ", size(kfn.zit)) - # println(">>> lif_zit ", size(kfn.lif_zit)) + println(">>> lif_zit ", size(kfn.lif_zit)) # println(">>> lif_recSignal ", size(kfn.lif_recSignal)) println(">>> lif_wRec ", size(kfn.lif_wRec)) println(">>> lif_refractoryCounter ", size(kfn.lif_refractoryCounter)) @@ -27,10 +30,8 @@ function (kfn::kfn_1)(input::AbstractArray) println(">>> lif_vt0 sum ", sum(kfn.lif_vt0)) # pass input_data into input neuron. - s1, s2, s3 = size(input) - GeneralUtils.cartesianAssign!(kfn.zit, reshape(input, (s1, s2, 1, s3))) - - #WORKING LIF forward active neurons + GeneralUtils.cartesianAssign!(kfn.zit, input) + lifForward( kfn.zit, kfn.lif_zit, kfn.lif_wRec, @@ -47,156 +48,221 @@ function (kfn::kfn_1)(input::AbstractArray) kfn.lif_gammaPd) alifForward( kfn.zit, - kfn.alif_zit, - kfn.alif_wRec, - kfn.alif_vt0, - kfn.alif_vt1, - kfn.alif_vth, - kfn.alif_avth, - kfn.alif_vRest, - kfn.alif_zt1, - kfn.alif_alpha, - kfn.alif_phi, - kfn.alif_epsilonRec, - kfn.alif_epsilonRecA, - kfn.alif_refractoryCounter, - kfn.alif_refractoryDuration, - kfn.alif_a, - kfn.alif_beta, - kfn.alif_rho, - kfn.alif_gammaPd) + kfn.alif_zit, + kfn.alif_wRec, + kfn.alif_vt0, + kfn.alif_vt1, + kfn.alif_vth, + kfn.alif_avth, + kfn.alif_vRest, + kfn.alif_zt1, + kfn.alif_alpha, + kfn.alif_phi, + kfn.alif_epsilonRec, + kfn.alif_epsilonRecA, + kfn.alif_refractoryCounter, + kfn.alif_refractoryDuration, + kfn.alif_a, + kfn.alif_beta, + kfn.alif_rho, + kfn.alif_gammaPd) + + # update activation matrix by concatenate (input, lif_zt1, alif_zt1) to form activation matrix + _zit = cat(reshape(input, (d1, d2, 1, d3)), + reshape(kfn.lif_zt1, (d1, :, 1, d3)), + reshape(kfn.alif_zt1, (d1, :, 1, d3)), dims=2) + kfn.zit .= reshape(_zit, (d1, :, d3)) + #WORKING read out + onForward( kfn.zit, + kfn.on_zit, + kfn.on_wRec, + kfn.on_vt0, + kfn.on_vt1, + kfn.on_vth, + kfn.on_vRest, + kfn.on_zt1, + kfn.on_alpha, + kfn.on_phi, + kfn.on_epsilonRec, + kfn.on_refractoryCounter, + kfn.on_refractoryDuration, + kfn.on_gammaPd) - - error("debug end kfn forward") + return kfn.on_zt1 end - -function lifForward(zit, - lif_zit, - lif_wRec, - lif_vt0, - lif_vt1, - lif_vth, - lif_vRest, - lif_zt1, - lif_alpha, - lif_phi, - lif_epsilonRec, - lif_refractoryCounter, - lif_refractoryDuration, - lif_gammaPd) - _, _, d3, d4 = size(lif_wRec) - lif_zit .= zit .* ones(size(lif_wRec)...) # project zit into lif_zit +function lifForward(kfn_zit, + zit, + wRec, + vt0, + vt1, + vth, + vRest, + zt1, + alpha, + phi, + epsilonRec, + refractoryCounter, + refractoryDuration, + gammaPd) + d1, d2, d3, d4 = size(wRec) + zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wRec)...) # project zit into zit for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch - if view(lif_refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active - view(lif_refractoryCounter, :, :, i, j)[1] -= 1 - view(lif_zt1, :, :, i, j)[1] = 0 - view(lif_vt1, :, :, i, j)[1] = - view(lif_alpha, :, :, i, j)[1] * view(lif_vt0, :, :, i, j)[1] - view(lif_phi, :, :, i, j)[1] = 0.0 - view(lif_epsilonRec, :, :, i, j) .= view(lif_alpha, :, :, i, j)[1] .* - view(lif_epsilonRec, :, :, i, j) + if view(refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active + view(refractoryCounter, :, :, i, j)[1] -= 1 + view(zt1, :, :, i, j)[1] = 0 + view(vt1, :, :, i, j)[1] = + view(alpha, :, :, i, j)[1] * view(vt0, :, :, i, j)[1] + view(phi, :, :, i, j)[1] = 0.0 + view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .* + view(epsilonRec, :, :, i, j) else # refractory period is inactive - view(lif_vt1, :, :, i, j)[1] = - (view(lif_alpha, :, :, i, j)[1] * view(lif_vt0,:, :, i, j)[1]) + - sum(view(lif_zit, :, :, i, j) .* view(lif_wRec, :, :, i, j)) - if view(lif_vt1, :, :, i, j)[1] > view(lif_vth, :, :, i, j)[1] - view(lif_zt1, :, :, i, j)[1] = 1 - view(lif_refractoryCounter, :, :, i, j)[1] = - view(lif_refractoryDuration, :, :, i, j)[1] - view(lif_firingCounter, :, :, i, j)[1] += 1 - view(lif_vt1, :, :, i, j)[1] = view(lif_vRest, :, :, i, j)[1] + view(vt1, :, :, i, j)[1] = + (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) + + sum(view(zit, :, :, i, j) .* view(wRec, :, :, i, j)) + if view(vt1, :, :, i, j)[1] > view(vth, :, :, i, j)[1] + view(zt1, :, :, i, j)[1] = 1 + view(refractoryCounter, :, :, i, j)[1] = + view(refractoryDuration, :, :, i, j)[1] + view(firingCounter, :, :, i, j)[1] += 1 + view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1] else - view(lif_zt1, :, :, i, j)[1] = 0 + view(zt1, :, :, i, j)[1] = 0 end # there is a difference from alif formula - view(lif_phi, :, :, i, j)[1] = - (view(lif_gammaPd, :, :, i, j)[1] / view(lif_vth, :, :, i, j)[1]) * - max(0, 1 - ((view(lif_vt1, :, :, i, j)[1] - view(lif_vth, :, :, i, j)[1]) / - view(lif_vth, :, :, i, j)[1])) - view(lif_epsilonRec, :, :, i, j) .= - (view(lif_alpha, :, :, i, j)[1] .* view(lif_epsilonRec, :, :, i, j)) + - view(lif_zit, :, :, i, j) + view(phi, :, :, i, j)[1] = + (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) * + max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(vth, :, :, i, j)[1]) / + view(vth, :, :, i, j)[1])) + view(epsilonRec, :, :, i, j) .= + (view(alpha, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) + + view(zit, :, :, i, j) end end end -function alifForward(zit, - alif_zit, - alif_wRec, - alif_vt0, - alif_vt1, - alif_vth, - alif_avth, - alif_vRest, - alif_zt1, - alif_alpha, - alif_phi, - alif_epsilonRec, - alif_epsilonRecA, - alif_refractoryCounter, - alif_refractoryDuration, - alif_a, - alif_beta, - alif_rho, - alif_gammaPd) - _, _, d3, d4 = size(alif_wRec) - alif_zit .= zit .* ones(size(alif_wRec)...) # project zit into alif_zit +function alifForward(kfn_zit, + zit, + wRec, + vt0, + vt1, + vth, + avth, + vRest, + zt1, + alpha, + phi, + epsilonRec, + epsilonRecA, + refractoryCounter, + refractoryDuration, + a, + beta, + rho, + gammaPd) + d1, d2, d3, d4 = size(wRec) + zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wRec)...) # project zit into zit for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch - if view(alif_refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active - view(alif_refractoryCounter, :, :, i, j)[1] -= 1 - view(alif_zt1, :, :, i, j)[1] = 0 - view(alif_vt1, :, :, i, j)[1] = view(alif_alpha, :, :, i, j)[1] * - view(alif_vt0, :, :, i, j)[1] - view(alif_phi, :, :, i, j)[1] = 0.0 - view(alif_epsilonRec, :, :, i, j) .= view(alif_alpha, :, :, i, j)[1] .* - view(alif_epsilonRec, :, :, i, j) - view(alif_a, :, :, i, j)[1] = - (view(alif_rho, :, :, i, j)[1] * view(alif_a, :, :, i, j)[1]) + 0 + if view(refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active + view(refractoryCounter, :, :, i, j)[1] -= 1 + view(zt1, :, :, i, j)[1] = 0 + view(vt1, :, :, i, j)[1] = view(alpha, :, :, i, j)[1] * + view(vt0, :, :, i, j)[1] + view(phi, :, :, i, j)[1] = 0.0 + view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .* + view(epsilonRec, :, :, i, j) + view(a, :, :, i, j)[1] = + (view(rho, :, :, i, j)[1] * view(a, :, :, i, j)[1]) + 0 else # refractory period is inactive - view(alif_vt1, :, :, i, j)[1] = - (view(alif_alpha, :, :, i, j)[1] * view(alif_vt0,:, :, i, j)[1]) + - sum(view(alif_zit, :, :, i, j) .* view(alif_wRec, :, :, i, j)) - view(alif_avth, :, :, i, j)[1] = view(alif_vth, :, :, i, j)[1] + - (view(alif_beta, :, :, i, j)[1] * view(alif_a, :, :, i, j)[1]) - if view(alif_vt1, :, :, i, j)[1] > view(alif_avth, :, :, i, j)[1] - view(alif_zt1, :, :, i, j)[1] = 1 - view(alif_refractoryCounter, :, :, i, j)[1] = - view(alif_refractoryDuration, :, :, i, j)[1] - view(alif_firingCounter, :, :, i, j)[1] += 1 - view(alif_vt1, :, :, i, j)[1] = view(alif_vRest, :, :, i, j)[1] - view(alif_a, :, :, i, j)[1] = (view(alif_rho, :, :, i, j)[1] * - view(alif_a, :, :, i, j)[1]) + 1 + view(vt1, :, :, i, j)[1] = + (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) + + sum(view(zit, :, :, i, j) .* view(wRec, :, :, i, j)) + view(avth, :, :, i, j)[1] = view(vth, :, :, i, j)[1] + + (view(beta, :, :, i, j)[1] * view(a, :, :, i, j)[1]) + if view(vt1, :, :, i, j)[1] > view(avth, :, :, i, j)[1] + view(zt1, :, :, i, j)[1] = 1 + view(refractoryCounter, :, :, i, j)[1] = + view(refractoryDuration, :, :, i, j)[1] + view(firingCounter, :, :, i, j)[1] += 1 + view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1] + view(a, :, :, i, j)[1] = (view(rho, :, :, i, j)[1] * + view(a, :, :, i, j)[1]) + 1 else - view(alif_zt1, :, :, i, j)[1] = 0 - view(alif_a, :, :, i, j)[1] = - (view(alif_rho, :, :, i, j)[1] * view(alif_a, :, :, i, j)[1]) + 0 + view(zt1, :, :, i, j)[1] = 0 + view(a, :, :, i, j)[1] = + (view(rho, :, :, i, j)[1] * view(a, :, :, i, j)[1]) + 0 end # there is a difference from alif formula - view(alif_phi, :, :, i, j)[1] = - (view(alif_gammaPd, :, :, i, j)[1] / view(alif_vth, :, :, i, j)[1]) * - max(0, 1 - ((view(alif_vt1, :, :, i, j)[1] - view(alif_avth, :, :, i, j)[1]) / - view(alif_vth, :, :, i, j)[1])) - view(alif_epsilonRec, :, :, i, j) .= - (view(alif_alpha, :, :, i, j) .* view(alif_epsilonRec, :, :, i, j)) + - view(alif_zit, :, :, i, j) - view(alif_epsilonRecA, :, :, i, j) .= - (view(alif_phi, :, :, i, j)[1] .* view(alif_epsilonRec, :, :, i, j)) + - ((view(alif_rho, :, :, i, j)[1] - - (view(alif_phi, :, :, i, j)[1] * view(alif_beta, :, :, i, j)[1])) .* - view(alif_epsilonRecA, :, :, i, j)) + view(phi, :, :, i, j)[1] = + (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) * + max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(avth, :, :, i, j)[1]) / + view(vth, :, :, i, j)[1])) + view(epsilonRec, :, :, i, j) .= + (view(alpha, :, :, i, j) .* view(epsilonRec, :, :, i, j)) + + view(zit, :, :, i, j) + view(epsilonRecA, :, :, i, j) .= + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) + + ((view(rho, :, :, i, j)[1] - + (view(phi, :, :, i, j)[1] * view(beta, :, :, i, j)[1])) .* + view(epsilonRecA, :, :, i, j)) end end end - - - - +function onForward(kfn_zit, + zit, + wRec, + vt0, + vt1, + vth, + vRest, + zt1, + alpha, + phi, + epsilonRec, + refractoryCounter, + refractoryDuration, + gammaPd) + d1, d2, d3, d4 = size(wRec) + zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wRec)...) # project zit into zit + + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + if view(refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active + view(refractoryCounter, :, :, i, j)[1] -= 1 + view(zt1, :, :, i, j)[1] = 0 + view(vt1, :, :, i, j)[1] = + view(alpha, :, :, i, j)[1] * view(vt0, :, :, i, j)[1] + view(phi, :, :, i, j)[1] = 0.0 + view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .* + view(epsilonRec, :, :, i, j) + else # refractory period is inactive + view(vt1, :, :, i, j)[1] = + (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) + + sum(view(zit, :, :, i, j) .* view(wRec, :, :, i, j)) + if view(vt1, :, :, i, j)[1] > view(vth, :, :, i, j)[1] + view(zt1, :, :, i, j)[1] = 1 + view(refractoryCounter, :, :, i, j)[1] = + view(refractoryDuration, :, :, i, j)[1] + view(firingCounter, :, :, i, j)[1] += 1 + view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1] + else + view(zt1, :, :, i, j)[1] = 0 + end + # there is a difference from alif formula + view(phi, :, :, i, j)[1] = + (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) * + max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(vth, :, :, i, j)[1]) / + view(vth, :, :, i, j)[1])) + view(epsilonRec, :, :, i, j) .= + (view(alpha, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) + + view(zit, :, :, i, j) + end + end +end diff --git a/src/type.jl b/src/type.jl index 6c66c0c..07d4102 100644 --- a/src/type.jl +++ b/src/type.jl @@ -24,7 +24,7 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn zit::Union{AbstractArray, Nothing} = nothing # 3D activation matrix # ---------------------------------------------------------------------------- # - # LIF # + # LIF Neurons # # ---------------------------------------------------------------------------- # # a projection of kfn.zit into lif dimension for broadcasting later) lif_zit::Union{AbstractArray, Nothing} = nothing @@ -49,7 +49,7 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn lif_firingCounter::Union{AbstractArray, Nothing} = nothing # ---------------------------------------------------------------------------- # - # ALIF # + # ALIF Neurons # # ---------------------------------------------------------------------------- # alif_zit::Union{AbstractArray, Nothing} = nothing @@ -78,24 +78,51 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn alif_beta::Union{AbstractArray, Nothing} = nothing # β, constant, value from paper alif_rho::Union{AbstractArray, Nothing} = nothing # ρ, threshold adaptation decay factor alif_tau_a::AbstractFloat = 100.0 # τ_a, adaption time constant in millisecond + + # ---------------------------------------------------------------------------- # + # Output Neurons # + # ---------------------------------------------------------------------------- # + # output neuron is based on LIF + on_zit::Union{AbstractArray, Nothing} = nothing + + on_wRec::Union{AbstractArray, Nothing} = nothing + on_vt0::Union{AbstractArray, Nothing} = nothing + on_vt1::Union{AbstractArray, Nothing} = nothing + on_vth::Union{AbstractArray, Nothing} = nothing + on_vRest::Union{AbstractArray, Nothing} = nothing + on_zt0::Union{AbstractArray, Nothing} = nothing + on_zt1::Union{AbstractArray, Nothing} = nothing + on_refractoryCounter::Union{AbstractArray, Nothing} = nothing + on_refractoryDuration::Union{AbstractArray, Nothing} = nothing + on_alpha::Union{AbstractArray, Nothing} = nothing + on_delta::AbstractFloat = 1.0 + on_tau_m::AbstractFloat = 20.0 + on_phi::Union{AbstractArray, Nothing} = nothing + on_epsilonRec::Union{AbstractArray, Nothing} = nothing + on_eta::Union{AbstractArray, Nothing} = nothing + on_gammaPd::Union{AbstractArray, Nothing} = nothing + + on_firingCounter::Union{AbstractArray, Nothing} = nothing end # outer constructor function kfn_1(params::Dict) kfn = kfn_1() kfn.params = params + # ---------------------------------------------------------------------------- # # initialize activation matrix # # ---------------------------------------------------------------------------- # # row*col is a 2D matrix represent all RSNN activation row, col, batch = kfn.params[:inputPort][:signal][:numbers] # z-axis represent signal batch number - row += kfn.params[:inputPort][:noise][:numbers][1] - col += kfn.params[:inputPort][:signal][:numbers][2] + # row += kfn.params[:inputPort][:noise][:numbers][1] + col += kfn.params[:inputPort][:noise][:numbers][2] col += kfn.params[:computeNeuron][:lif][:numbers][2] col += kfn.params[:computeNeuron][:alif][:numbers][2] # activation matrix - kfn.zit = zeros(row, col, 1, batch) + kfn.zit = zeros(row, col, batch) + # ---------------------------------------------------------------------------- # # LIF config # # ---------------------------------------------------------------------------- # @@ -129,8 +156,8 @@ function kfn_1(params::Dict) end # project 3D w into 4D kfn.lif_wRec kfn.lif_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) - kfn.lif_firingCounter = zeros(1, 1, z, batch) + # ---------------------------------------------------------------------------- # # ALIF config # # ---------------------------------------------------------------------------- # @@ -166,21 +193,48 @@ function kfn_1(params::Dict) slice[i] = randn()/10 # assign weight to synaptic connection end end - # project 3D w into 4D kfn.lif_wRec + # project 3D w into 4D kfn.alif_wRec kfn.alif_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) - kfn.alif_firingCounter = zeros(1, 1, z, batch) + # ---------------------------------------------------------------------------- # + # output config # + # ---------------------------------------------------------------------------- # + #WORKING + z = kfn.params[:outputPort][:numbers][1] * kfn.params[:outputPort][:numbers][2] + kfn.on_zit = zeros(row, col, z, batch) + kfn.on_vt0 = zeros(1, 1, z, batch) + kfn.on_vt1 = zeros(1, 1, z, batch) + kfn.on_vth = ones(1, 1, z, batch) + kfn.on_vRest = zeros(1, 1, z, batch) + kfn.on_zt0 = zeros(1, 1, z, batch) + kfn.on_zt1 = zeros(1, 1, z, batch) + kfn.on_refractoryCounter = zeros(1, 1, z, batch) + kfn.on_refractoryDuration = ones(1, 1, z, batch) .* 1 + kfn.on_alpha = ones(1, 1, z, batch) .* (exp(-kfn.on_delta / kfn.on_tau_m)) + kfn.on_phi = zeros(1, 1, z, batch) + kfn.on_epsilonRec = zeros(row, col, z, batch) + kfn.on_eta = zeros(1, 1, z, batch) + kfn.on_gammaPd = zeros(1, 1, z, batch) .* 0.3 + + # subscription + w = zeros(row, col, z) + synapticConnectionPercent = kfn.params[:outputPort][:params][:synapticConnectionPercent] + synapticConnection = Int(floor(row*col * synapticConnectionPercent/100)) + for slice in eachslice(w, dims=3) + pool = shuffle!([1:row*col...])[1:synapticConnection] + for i in pool + slice[i] = randn()/10 # assign weight to synaptic connection + end + end + # project 3D w into 4D kfn.on_wRec + kfn.on_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) + kfn.on_firingCounter = zeros(1, 1, z, batch) - - - - - # error("debug end outer constructor") return kfn end