From bdec057886f3db6404f57edf9402f3ed173810ac Mon Sep 17 00:00:00 2001 From: ton Date: Sun, 23 Jul 2023 15:39:48 +0700 Subject: [PATCH] alif forward --- src/forward.jl | 147 +++++++++++++++++++++++++++++++++++-------------- src/type.jl | 76 ++++++++++++++++++------- 2 files changed, 162 insertions(+), 61 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index f8ccb96..2e5d2bc 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -43,44 +43,32 @@ function (kfn::kfn_1)(input::AbstractArray) kfn.lif_phi, kfn.lif_epsilonRec, kfn.lif_refractoryCounter, - kfn.lif_refractoryDuration,) + kfn.lif_refractoryDuration, + kfn.lif_gammaPd) + alifForward( kfn.zit, + kfn.alif_zit, + kfn.alif_wRec, + kfn.alif_vt0, + kfn.alif_vt1, + kfn.alif_vth, + kfn.alif_avth, + kfn.alif_vRest, + kfn.alif_zt1, + kfn.alif_alpha, + kfn.alif_phi, + kfn.alif_epsilonRec, + kfn.alif_epsilonRecA, + kfn.alif_refractoryCounter, + kfn.alif_refractoryDuration, + kfn.alif_a, + kfn.alif_beta, + kfn.alif_rho, + kfn.alif_gammaPd) error("debug end kfn forward") - - - # kfn.lif_zit = GeneralUtils.matMul_3Dto4D_batchwise(kfn.zit, - # ones(size(kfn.zit)[1], size(kfn.zit)[2], size(kfn.lif_wRec)[3], size(kfn.zit)[3])) - - - - # check active/inactive neurons - # refractoryStatus!(kfn.lif_refractoryCounter, kfn.lif_refractoryActive, kfn.lif_refractoryInactive) - # refractoryStatus!(kfn.alif_refractoryCounter, kfn.alif_refractoryActive, kfn.alif_refractoryInactive) - - - # a = kfn.lif_refractoryActive .* kfn.lif_wRec - # lifForward.(kfn.lif_refractoryCounter, kfn.zit0, kfn.zit1, - # kfn.lif_vt0, kfn.lif_vt1, kfn.lif_alpha, kfn.lif_recSignal) - - # kfn.lif_recSignal .= GeneralUtils.sumAlongDim3( - # GeneralUtils.matMul_3Dto4D_batchwise(kfn.zit1, kfn.lif_refractoryActive .* kfn.lif_wRec)) - # kfn.lif_vt1 = (kfn.lif_alpha .* kfn.lif_vt0) .+ kfn.lif_recSignal - - - - - - # GeneralUtils.batchMatEleMul(kfn.zit1, kfn.alif_wRec, resultStorage=kfn.alif_recSignal) - - - - - - - end @@ -96,39 +84,114 @@ function lifForward(zit, lif_phi, lif_epsilonRec, lif_refractoryCounter, - lif_refractoryDuration,) + lif_refractoryDuration, + lif_gammaPd) _, _, d3, d4 = size(lif_wRec) lif_zit .= zit .* ones(size(lif_wRec)...) # project zit into lif_zit - for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch if view(lif_refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active view(lif_refractoryCounter, :, :, i, j)[1] -= 1 view(lif_zt1, :, :, i, j)[1] = 0 - view(lif_vt1, :, :, i, j)[1] = view(lif_alpha, :, :, i, j)[1] * view(lif_vt0, :, :, i, j)[1] + view(lif_vt1, :, :, i, j)[1] = + view(lif_alpha, :, :, i, j)[1] * view(lif_vt0, :, :, i, j)[1] view(lif_phi, :, :, i, j)[1] = 0.0 view(lif_epsilonRec, :, :, i, j) .= view(lif_alpha, :, :, i, j)[1] .* view(lif_epsilonRec, :, :, i, j) else # refractory period is inactive view(lif_vt1, :, :, i, j)[1] = - (view(lif_alpha, :, :, i, j)[1] * view(lif_vt0,:, :, i, j)[1]) + - sum(view(lif_zit, :, :, i, j) .* view(lif_wRec, :, :, i, j)) + (view(lif_alpha, :, :, i, j)[1] * view(lif_vt0,:, :, i, j)[1]) + + sum(view(lif_zit, :, :, i, j) .* view(lif_wRec, :, :, i, j)) if view(lif_vt1, :, :, i, j)[1] > view(lif_vth, :, :, i, j)[1] view(lif_zt1, :, :, i, j)[1] = 1 - view(lif_refractoryCounter, :, :, i, j)[1] = view(lif_refractoryDuration, :, :, i, j)[1] + view(lif_refractoryCounter, :, :, i, j)[1] = + view(lif_refractoryDuration, :, :, i, j)[1] view(lif_firingCounter, :, :, i, j)[1] += 1 view(lif_vt1, :, :, i, j)[1] = view(lif_vRest, :, :, i, j)[1] else view(lif_zt1, :, :, i, j)[1] = 0 end + # there is a difference from alif formula + view(lif_phi, :, :, i, j)[1] = + (view(lif_gammaPd, :, :, i, j)[1] / view(lif_vth, :, :, i, j)[1]) * + max(0, 1 - ((view(lif_vt1, :, :, i, j)[1] - view(lif_vth, :, :, i, j)[1]) / + view(lif_vth, :, :, i, j)[1])) + view(lif_epsilonRec, :, :, i, j) .= + (view(lif_alpha, :, :, i, j)[1] .* view(lif_epsilonRec, :, :, i, j)) + + view(lif_zit, :, :, i, j) end end - - error("debug end -> LIF forward") end +function alifForward(zit, + alif_zit, + alif_wRec, + alif_vt0, + alif_vt1, + alif_vth, + alif_avth, + alif_vRest, + alif_zt1, + alif_alpha, + alif_phi, + alif_epsilonRec, + alif_epsilonRecA, + alif_refractoryCounter, + alif_refractoryDuration, + alif_a, + alif_beta, + alif_rho, + alif_gammaPd) + _, _, d3, d4 = size(alif_wRec) + alif_zit .= zit .* ones(size(alif_wRec)...) # project zit into alif_zit + + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + if view(alif_refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active + view(alif_refractoryCounter, :, :, i, j)[1] -= 1 + view(alif_zt1, :, :, i, j)[1] = 0 + view(alif_vt1, :, :, i, j)[1] = view(alif_alpha, :, :, i, j)[1] * + view(alif_vt0, :, :, i, j)[1] + view(alif_phi, :, :, i, j)[1] = 0.0 + view(alif_epsilonRec, :, :, i, j) .= view(alif_alpha, :, :, i, j)[1] .* + view(alif_epsilonRec, :, :, i, j) + view(alif_a, :, :, i, j)[1] = + (view(alif_rho, :, :, i, j)[1] * view(alif_a, :, :, i, j)[1]) + 0 + else # refractory period is inactive + view(alif_vt1, :, :, i, j)[1] = + (view(alif_alpha, :, :, i, j)[1] * view(alif_vt0,:, :, i, j)[1]) + + sum(view(alif_zit, :, :, i, j) .* view(alif_wRec, :, :, i, j)) + view(alif_avth, :, :, i, j)[1] = view(alif_vth, :, :, i, j)[1] + + (view(alif_beta, :, :, i, j)[1] * view(alif_a, :, :, i, j)[1]) + if view(alif_vt1, :, :, i, j)[1] > view(alif_avth, :, :, i, j)[1] + view(alif_zt1, :, :, i, j)[1] = 1 + view(alif_refractoryCounter, :, :, i, j)[1] = + view(alif_refractoryDuration, :, :, i, j)[1] + view(alif_firingCounter, :, :, i, j)[1] += 1 + view(alif_vt1, :, :, i, j)[1] = view(alif_vRest, :, :, i, j)[1] + view(alif_a, :, :, i, j)[1] = (view(alif_rho, :, :, i, j)[1] * + view(alif_a, :, :, i, j)[1]) + 1 + else + view(alif_zt1, :, :, i, j)[1] = 0 + view(alif_a, :, :, i, j)[1] = + (view(alif_rho, :, :, i, j)[1] * view(alif_a, :, :, i, j)[1]) + 0 + end - + # there is a difference from alif formula + view(alif_phi, :, :, i, j)[1] = + (view(alif_gammaPd, :, :, i, j)[1] / view(alif_vth, :, :, i, j)[1]) * + max(0, 1 - ((view(alif_vt1, :, :, i, j)[1] - view(alif_avth, :, :, i, j)[1]) / + view(alif_vth, :, :, i, j)[1])) + view(alif_epsilonRec, :, :, i, j) .= + (view(alif_alpha, :, :, i, j) .* view(alif_epsilonRec, :, :, i, j)) + + view(alif_zit, :, :, i, j) + view(alif_epsilonRecA, :, :, i, j) .= + (view(alif_phi, :, :, i, j)[1] .* view(alif_epsilonRec, :, :, i, j)) + + ((view(alif_rho, :, :, i, j)[1] - + (view(alif_phi, :, :, i, j)[1] * view(alif_beta, :, :, i, j)[1])) .* + view(alif_epsilonRecA, :, :, i, j)) + end + end +end diff --git a/src/type.jl b/src/type.jl index 43f633e..6c66c0c 100644 --- a/src/type.jl +++ b/src/type.jl @@ -30,7 +30,6 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn lif_zit::Union{AbstractArray, Nothing} = nothing lif_wRec::Union{AbstractArray, Nothing} = nothing - # lif_recSignal::Union{AbstractArray, Nothing} = nothing lif_vt0::Union{AbstractArray, Nothing} = nothing lif_vt1::Union{AbstractArray, Nothing} = nothing lif_vth::Union{AbstractArray, Nothing} = nothing @@ -39,33 +38,55 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn lif_zt1::Union{AbstractArray, Nothing} = nothing lif_refractoryCounter::Union{AbstractArray, Nothing} = nothing lif_refractoryDuration::Union{AbstractArray, Nothing} = nothing - # lif_refractoryActive::Union{AbstractArray, Nothing} = nothing - # lif_refractoryInactive::Union{AbstractArray, Nothing} = nothing lif_alpha::Union{AbstractArray, Nothing} = nothing lif_delta::AbstractFloat = 1.0 lif_tau_m::AbstractFloat = 20.0 lif_phi::Union{AbstractArray, Nothing} = nothing lif_epsilonRec::Union{AbstractArray, Nothing} = nothing + lif_eta::Union{AbstractArray, Nothing} = nothing + lif_gammaPd::Union{AbstractArray, Nothing} = nothing lif_firingCounter::Union{AbstractArray, Nothing} = nothing # ---------------------------------------------------------------------------- # # ALIF # # ---------------------------------------------------------------------------- # + alif_zit::Union{AbstractArray, Nothing} = nothing + alif_wRec::Union{AbstractArray, Nothing} = nothing - alif_recSignal::Union{AbstractArray, Nothing} = nothing + alif_vt0::Union{AbstractArray, Nothing} = nothing + alif_vt1::Union{AbstractArray, Nothing} = nothing + alif_vth::Union{AbstractArray, Nothing} = nothing + alif_avth::Union{AbstractArray, Nothing} = nothing + alif_vRest::Union{AbstractArray, Nothing} = nothing alif_zt0::Union{AbstractArray, Nothing} = nothing alif_zt1::Union{AbstractArray, Nothing} = nothing alif_refractoryCounter::Union{AbstractArray, Nothing} = nothing - alif_refractoryActive::Union{AbstractArray, Nothing} = nothing - alif_refractoryInactive::Union{AbstractArray, Nothing} = nothing + alif_refractoryDuration::Union{AbstractArray, Nothing} = nothing + alif_alpha::Union{AbstractArray, Nothing} = nothing + alif_delta::AbstractFloat = 1.0 + alif_tau_m::AbstractFloat = 20.0 + alif_phi::Union{AbstractArray, Nothing} = nothing + alif_epsilonRec::Union{AbstractArray, Nothing} = nothing + alif_epsilonRecA::Union{AbstractArray, Nothing} = nothing + alif_eta::Union{AbstractArray, Nothing} = nothing + alif_gammaPd::Union{AbstractArray, Nothing} = nothing + + alif_firingCounter::Union{AbstractArray, Nothing} = nothing + + alif_a::Union{AbstractArray, Nothing} = nothing # threshold adaptation + alif_beta::Union{AbstractArray, Nothing} = nothing # β, constant, value from paper + alif_rho::Union{AbstractArray, Nothing} = nothing # ρ, threshold adaptation decay factor + alif_tau_a::AbstractFloat = 100.0 # τ_a, adaption time constant in millisecond end # outer constructor function kfn_1(params::Dict) kfn = kfn_1() kfn.params = params - # ----------------------- initialize activation matrix ----------------------- # + # ---------------------------------------------------------------------------- # + # initialize activation matrix # + # ---------------------------------------------------------------------------- # # row*col is a 2D matrix represent all RSNN activation row, col, batch = kfn.params[:inputPort][:signal][:numbers] # z-axis represent signal batch number row += kfn.params[:inputPort][:noise][:numbers][1] @@ -74,14 +95,14 @@ function kfn_1(params::Dict) col += kfn.params[:computeNeuron][:alif][:numbers][2] # activation matrix - kfn.zit = zeros(row, col, 1, batch) - - # -------------------------------- LIF config -------------------------------- # + kfn.zit = zeros(row, col, 1, batch) + # ---------------------------------------------------------------------------- # + # LIF config # + # ---------------------------------------------------------------------------- # # In 3D LIF matrix, z-axis represent each neuron while each 2D slice represent that neuron's # synaptic subscription to other neurons (via activation matrix) z = kfn.params[:computeNeuron][:lif][:numbers][1] * kfn.params[:computeNeuron][:lif][:numbers][2] kfn.lif_zit = zeros(row, col, z, batch) - # kfn.lif_recSignal = zeros(1, 1, z, batch) kfn.lif_vt0 = zeros(1, 1, z, batch) kfn.lif_vt1 = zeros(1, 1, z, batch) kfn.lif_vth = ones(1, 1, z, batch) @@ -90,11 +111,11 @@ function kfn_1(params::Dict) kfn.lif_zt1 = zeros(1, 1, z, batch) kfn.lif_refractoryCounter = zeros(1, 1, z, batch) kfn.lif_refractoryDuration = ones(1, 1, z, batch) .* 3 - # kfn.lif_refractoryActive = zeros(1, 1, z, batch) - # kfn.lif_refractoryInactive = zeros(1, 1, z, batch) kfn.lif_alpha = ones(1, 1, z, batch) .* (exp(-kfn.lif_delta / kfn.lif_tau_m)) kfn.lif_phi = zeros(1, 1, z, batch) kfn.lif_epsilonRec = zeros(row, col, z, batch) + kfn.lif_eta = zeros(1, 1, z, batch) + kfn.lif_gammaPd = zeros(1, 1, z, batch) .* 0.3 # subscription w = zeros(row, col, z) @@ -110,15 +131,30 @@ function kfn_1(params::Dict) kfn.lif_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) kfn.lif_firingCounter = zeros(1, 1, z, batch) - - # -------------------------------- ALIF config ------------------------------- # + # ---------------------------------------------------------------------------- # + # ALIF config # + # ---------------------------------------------------------------------------- # z = kfn.params[:computeNeuron][:alif][:numbers][1] * kfn.params[:computeNeuron][:alif][:numbers][2] - kfn.alif_recSignal = zeros(1, 1, z, batch) + kfn.alif_zit = zeros(row, col, z, batch) + kfn.alif_vt0 = zeros(1, 1, z, batch) + kfn.alif_vt1 = zeros(1, 1, z, batch) + kfn.alif_vth = ones(1, 1, z, batch) + kfn.alif_avth = ones(1, 1, z, batch) + kfn.alif_vRest = zeros(1, 1, z, batch) kfn.alif_zt0 = zeros(1, 1, z, batch) kfn.alif_zt1 = zeros(1, 1, z, batch) kfn.alif_refractoryCounter = zeros(1, 1, z, batch) - kfn.alif_refractoryActive = zeros(1, 1, z, batch) - kfn.alif_refractoryInactive = zeros(1, 1, z, batch) + kfn.alif_refractoryDuration = ones(1, 1, z, batch) .* 3 + kfn.alif_alpha = ones(1, 1, z, batch) .* (exp(-kfn.alif_delta / kfn.alif_tau_m)) + kfn.alif_phi = zeros(1, 1, z, batch) + kfn.alif_epsilonRec = zeros(row, col, z, batch) + kfn.alif_epsilonRecA = zeros(row, col, z, batch) + kfn.alif_eta = zeros(1, 1, z, batch) + kfn.alif_gammaPd = zeros(1, 1, z, batch) .* 0.3 + + kfn.alif_a = zeros(1, 1, z, batch) + kfn.alif_beta = zeros(1, 1, z, batch) .* 0.15 + kfn.alif_rho = zeros(1, 1, z, batch) .* (exp(-kfn.alif_delta / kfn.alif_tau_a)) # subscription w = zeros(row, col, z) @@ -127,12 +163,14 @@ function kfn_1(params::Dict) for slice in eachslice(w, dims=3) pool = shuffle!([1:row*col...])[1:synapticConnection] for i in pool - slice[i] = randn()/10 + slice[i] = randn()/10 # assign weight to synaptic connection end end # project 3D w into 4D kfn.lif_wRec kfn.alif_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) + kfn.alif_firingCounter = zeros(1, 1, z, batch) +