From b7c87bd0fab6bb8ee6204ee9a7faa94c627eb9b7 Mon Sep 17 00:00:00 2001 From: ton Date: Sun, 23 Jul 2023 11:03:08 +0700 Subject: [PATCH] lif forward --- src/forward.jl | 141 ++++++++++++++++++++++++++++--------------------- src/snnUtil.jl | 1 - src/type.jl | 56 +++++++++++--------- 3 files changed, 114 insertions(+), 84 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index 095c5b4..f8ccb96 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -17,91 +17,114 @@ function (kfn::kfn_1)(input::AbstractArray) end println(">>> input ", size(input)) - - # pass input_data into input neuron. - GeneralUtils.cartesianAssign!(kfn.z_i_t, input) - kfn.lif_z_i_t = GeneralUtils.matMul_3Dto4D_batchwise(kfn.z_i_t, - ones(size(kfn.z_i_t)[1], size(kfn.z_i_t)[2], size(kfn.lif_w)[3], size(kfn.z_i_t)[3])) - - println(">>> z_i_t ", size(kfn.z_i_t)) - println(">>> lif_z_i_t ", size(kfn.lif_z_i_t)) - println(">>> lif_recSignal ", size(kfn.lif_recSignal)) - println(">>> lif_w ", size(kfn.lif_w)) - println(">>> lif_refractoryActive ", size(kfn.lif_refractoryCounter)) + println(">>> zit ", size(kfn.zit)) + # println(">>> lif_zit ", size(kfn.lif_zit)) + # println(">>> lif_recSignal ", size(kfn.lif_recSignal)) + println(">>> lif_wRec ", size(kfn.lif_wRec)) + println(">>> lif_refractoryCounter ", size(kfn.lif_refractoryCounter)) println(">>> lif_alpha ", size(kfn.lif_alpha)) println(">>> lif_vt0 ", size(kfn.lif_vt0)) println(">>> lif_vt0 sum ", sum(kfn.lif_vt0)) - # check active/inactive neurons - refractoryStatus!(kfn.lif_refractoryCounter, kfn.lif_refractoryActive, kfn.lif_refractoryInactive) - refractoryStatus!(kfn.alif_refractoryCounter, kfn.alif_refractoryActive, kfn.alif_refractoryInactive) - + # pass input_data into input neuron. + s1, s2, s3 = size(input) + GeneralUtils.cartesianAssign!(kfn.zit, reshape(input, (s1, s2, 1, s3))) + #WORKING LIF forward active neurons - # a = kfn.lif_refractoryActive .* kfn.lif_w - # lifForward.(kfn.lif_refractoryCounter, kfn.z_i_t0, kfn.z_i_t1, + lifForward( kfn.zit, + kfn.lif_zit, + kfn.lif_wRec, + kfn.lif_vt0, + kfn.lif_vt1, + kfn.lif_vth, + kfn.lif_vRest, + kfn.lif_zt1, + kfn.lif_alpha, + kfn.lif_phi, + kfn.lif_epsilonRec, + kfn.lif_refractoryCounter, + kfn.lif_refractoryDuration,) + + + + + error("debug end kfn forward") + + + # kfn.lif_zit = GeneralUtils.matMul_3Dto4D_batchwise(kfn.zit, + # ones(size(kfn.zit)[1], size(kfn.zit)[2], size(kfn.lif_wRec)[3], size(kfn.zit)[3])) + + + + # check active/inactive neurons + # refractoryStatus!(kfn.lif_refractoryCounter, kfn.lif_refractoryActive, kfn.lif_refractoryInactive) + # refractoryStatus!(kfn.alif_refractoryCounter, kfn.alif_refractoryActive, kfn.alif_refractoryInactive) + + + # a = kfn.lif_refractoryActive .* kfn.lif_wRec + # lifForward.(kfn.lif_refractoryCounter, kfn.zit0, kfn.zit1, # kfn.lif_vt0, kfn.lif_vt1, kfn.lif_alpha, kfn.lif_recSignal) # kfn.lif_recSignal .= GeneralUtils.sumAlongDim3( - # GeneralUtils.matMul_3Dto4D_batchwise(kfn.z_i_t1, kfn.lif_refractoryActive .* kfn.lif_w)) + # GeneralUtils.matMul_3Dto4D_batchwise(kfn.zit1, kfn.lif_refractoryActive .* kfn.lif_wRec)) # kfn.lif_vt1 = (kfn.lif_alpha .* kfn.lif_vt0) .+ kfn.lif_recSignal - # GeneralUtils.batchMatEleMul(kfn.z_i_t1, kfn.alif_w, resultStorage=kfn.alif_recSignal) + # GeneralUtils.batchMatEleMul(kfn.zit1, kfn.alif_wRec, resultStorage=kfn.alif_recSignal) - error("debug end kfn forward") + end -function lifForward(lif_refractoryCounter, z_i_t0, z_i_t1, lif_w, lif_vt0, lif_vt1, lif_alpha, - lif_recSignal) - error("debug end LIF forward") - - - - # if n.refractoryCounter != 0 - # n.refractoryCounter -= 1 +function lifForward(zit, + lif_zit, + lif_wRec, + lif_vt0, + lif_vt1, + lif_vth, + lif_vRest, + lif_zt1, + lif_alpha, + lif_phi, + lif_epsilonRec, + lif_refractoryCounter, + lif_refractoryDuration,) + _, _, d3, d4 = size(lif_wRec) + lif_zit .= zit .* ones(size(lif_wRec)...) # project zit into lif_zit - # # neuron is in refractory state, skip all calculation - # n.z_t1 = false # used by timestep_forward() in kfn. Set to zero because neuron spike - # # last only 1 timestep follow by a period of refractory. - # n.recSignal = n.recSignal * 0.0 - # # decay of v_t1 - # n.v_t1 = n.alpha * n.v_t + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + if view(lif_refractoryCounter, :, :, i, j)[1] > 0 # refractory period is active + view(lif_refractoryCounter, :, :, i, j)[1] -= 1 + view(lif_zt1, :, :, i, j)[1] = 0 + view(lif_vt1, :, :, i, j)[1] = view(lif_alpha, :, :, i, j)[1] * view(lif_vt0, :, :, i, j)[1] + view(lif_phi, :, :, i, j)[1] = 0.0 + view(lif_epsilonRec, :, :, i, j) .= view(lif_alpha, :, :, i, j)[1] .* + view(lif_epsilonRec, :, :, i, j) + else # refractory period is inactive + view(lif_vt1, :, :, i, j)[1] = + (view(lif_alpha, :, :, i, j)[1] * view(lif_vt0,:, :, i, j)[1]) + + sum(view(lif_zit, :, :, i, j) .* view(lif_wRec, :, :, i, j)) + if view(lif_vt1, :, :, i, j)[1] > view(lif_vth, :, :, i, j)[1] + view(lif_zt1, :, :, i, j)[1] = 1 + view(lif_refractoryCounter, :, :, i, j)[1] = view(lif_refractoryDuration, :, :, i, j)[1] + view(lif_firingCounter, :, :, i, j)[1] += 1 + view(lif_vt1, :, :, i, j)[1] = view(lif_vRest, :, :, i, j)[1] + else + view(lif_zt1, :, :, i, j)[1] = 0 + end + end + end - # n.phi = 0.0 - # n.decayedEpsilonRec = n.alpha * n.epsilonRec - # n.epsilonRec = n.decayedEpsilonRec - # else - # n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed - - # # computeAlpha!(n) - # n.alpha_v_t = n.alpha * n.v_t - # n.v_t1 = n.alpha_v_t + n.recSignal - # # n.v_t1 = no_negative!(n.v_t1) - - # if n.v_t1 > n.v_th - # n.z_t1 = true - # n.refractoryCounter = n.refractoryDuration - # n.firingCounter += 1 - # n.v_t1 = n.vRest - # else - # n.z_t1 = false - # end - - # # there is a difference from alif formula - # n.phi = (n.gammaPd / n.v_th) * max(0, 1 - (n.v_t1 - n.v_th) / n.v_th) - # n.decayedEpsilonRec = n.alpha * n.epsilonRec - # n.epsilonRec = n.decayedEpsilonRec + n.z_i_t - # end + error("debug end -> LIF forward") end diff --git a/src/snnUtil.jl b/src/snnUtil.jl index 7edf5c8..168ba9e 100644 --- a/src/snnUtil.jl +++ b/src/snnUtil.jl @@ -72,7 +72,6 @@ end - end # module \ No newline at end of file diff --git a/src/type.jl b/src/type.jl index f1a7312..43f633e 100644 --- a/src/type.jl +++ b/src/type.jl @@ -21,40 +21,44 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn timeStep::AbstractArray = [0] learningStage::AbstractArray = [0] # 0 inference, 1 start, 2 during, 3 end learning - z_i_t::Union{AbstractArray, Nothing} = nothing # 3D activation matrix + zit::Union{AbstractArray, Nothing} = nothing # 3D activation matrix # ---------------------------------------------------------------------------- # # LIF # # ---------------------------------------------------------------------------- # - # a projection of kfn.z_i_t into lif dimension for broadcasting later) - lif_z_i_t::Union{AbstractArray, Nothing} = nothing + # a projection of kfn.zit into lif dimension for broadcasting later) + lif_zit::Union{AbstractArray, Nothing} = nothing - lif_w::Union{AbstractArray, Nothing} = nothing - lif_recSignal::Union{AbstractArray, Nothing} = nothing + lif_wRec::Union{AbstractArray, Nothing} = nothing + # lif_recSignal::Union{AbstractArray, Nothing} = nothing lif_vt0::Union{AbstractArray, Nothing} = nothing lif_vt1::Union{AbstractArray, Nothing} = nothing lif_vth::Union{AbstractArray, Nothing} = nothing + lif_vRest::Union{AbstractArray, Nothing} = nothing lif_zt0::Union{AbstractArray, Nothing} = nothing lif_zt1::Union{AbstractArray, Nothing} = nothing lif_refractoryCounter::Union{AbstractArray, Nothing} = nothing - lif_refractoryActive::Union{AbstractArray, Nothing} = nothing - lif_refractoryInactive::Union{AbstractArray, Nothing} = nothing + lif_refractoryDuration::Union{AbstractArray, Nothing} = nothing + # lif_refractoryActive::Union{AbstractArray, Nothing} = nothing + # lif_refractoryInactive::Union{AbstractArray, Nothing} = nothing lif_alpha::Union{AbstractArray, Nothing} = nothing lif_delta::AbstractFloat = 1.0 lif_tau_m::AbstractFloat = 20.0 + lif_phi::Union{AbstractArray, Nothing} = nothing + lif_epsilonRec::Union{AbstractArray, Nothing} = nothing + + lif_firingCounter::Union{AbstractArray, Nothing} = nothing # ---------------------------------------------------------------------------- # # ALIF # # ---------------------------------------------------------------------------- # - alif_w::Union{AbstractArray, Nothing} = nothing + alif_wRec::Union{AbstractArray, Nothing} = nothing alif_recSignal::Union{AbstractArray, Nothing} = nothing alif_zt0::Union{AbstractArray, Nothing} = nothing alif_zt1::Union{AbstractArray, Nothing} = nothing alif_refractoryCounter::Union{AbstractArray, Nothing} = nothing alif_refractoryActive::Union{AbstractArray, Nothing} = nothing alif_refractoryInactive::Union{AbstractArray, Nothing} = nothing - - end # outer constructor @@ -70,23 +74,27 @@ function kfn_1(params::Dict) col += kfn.params[:computeNeuron][:alif][:numbers][2] # activation matrix - kfn.z_i_t = zeros(row, col, batch) + kfn.zit = zeros(row, col, 1, batch) # -------------------------------- LIF config -------------------------------- # # In 3D LIF matrix, z-axis represent each neuron while each 2D slice represent that neuron's # synaptic subscription to other neurons (via activation matrix) z = kfn.params[:computeNeuron][:lif][:numbers][1] * kfn.params[:computeNeuron][:lif][:numbers][2] - - kfn.lif_recSignal = zeros(1, 1, z, batch) + kfn.lif_zit = zeros(row, col, z, batch) + # kfn.lif_recSignal = zeros(1, 1, z, batch) kfn.lif_vt0 = zeros(1, 1, z, batch) kfn.lif_vt1 = zeros(1, 1, z, batch) kfn.lif_vth = ones(1, 1, z, batch) + kfn.lif_vRest = zeros(1, 1, z, batch) kfn.lif_zt0 = zeros(1, 1, z, batch) kfn.lif_zt1 = zeros(1, 1, z, batch) kfn.lif_refractoryCounter = zeros(1, 1, z, batch) - kfn.lif_refractoryActive = zeros(1, 1, z, batch) - kfn.lif_refractoryInactive = zeros(1, 1, z, batch) + kfn.lif_refractoryDuration = ones(1, 1, z, batch) .* 3 + # kfn.lif_refractoryActive = zeros(1, 1, z, batch) + # kfn.lif_refractoryInactive = zeros(1, 1, z, batch) kfn.lif_alpha = ones(1, 1, z, batch) .* (exp(-kfn.lif_delta / kfn.lif_tau_m)) + kfn.lif_phi = zeros(1, 1, z, batch) + kfn.lif_epsilonRec = zeros(row, col, z, batch) # subscription w = zeros(row, col, z) @@ -98,14 +106,13 @@ function kfn_1(params::Dict) slice[i] = randn()/10 # assign weight to synaptic connection end end - #WORKING project 3D w into 4D kfn.lif_w - kfn.lif_w = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) - println(">>> lif_w ", size(kfn.lif_w)) - error("end WORKING") + # project 3D w into 4D kfn.lif_wRec + kfn.lif_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch) - # ALIF + kfn.lif_firingCounter = zeros(1, 1, z, batch) + + # -------------------------------- ALIF config ------------------------------- # z = kfn.params[:computeNeuron][:alif][:numbers][1] * kfn.params[:computeNeuron][:alif][:numbers][2] - kfn.alif_w = zeros(row, col, z) # matrix z-axis represent each neurons kfn.alif_recSignal = zeros(1, 1, z, batch) kfn.alif_zt0 = zeros(1, 1, z, batch) kfn.alif_zt1 = zeros(1, 1, z, batch) @@ -114,16 +121,17 @@ function kfn_1(params::Dict) kfn.alif_refractoryInactive = zeros(1, 1, z, batch) # subscription - row, col, _ = size(kfn.alif_w) # row*col is synaptic subscribe weight for each neuron in z-axis + w = zeros(row, col, z) synapticConnectionPercent = kfn.params[:computeNeuron][:alif][:params][:synapticConnectionPercent] synapticConnection = Int(floor(row*col * synapticConnectionPercent/100)) - for slice in eachslice(kfn.alif_w, dims=3) + for slice in eachslice(w, dims=3) pool = shuffle!([1:row*col...])[1:synapticConnection] for i in pool slice[i] = randn()/10 end end - + # project 3D w into 4D kfn.lif_wRec + kfn.alif_wRec = reshape(w, (row, col, z, 1)) .* ones(row, col, z, batch)