add zit_cumulative

This commit is contained in:
ton
2023-08-19 06:22:03 +07:00
parent 57373621b8
commit f656c62266
2 changed files with 5 additions and 0 deletions

View File

@@ -18,6 +18,8 @@ function (kfn::kfn_1)(input::AbstractArray)
# what to do at the start of learning round # what to do at the start of learning round
if view(kfn.learningStage, 1)[1] == 1 if view(kfn.learningStage, 1)[1] == 1
# reset learning params # reset learning params
kfn.zit_cumulative .= 0
kfn.lif_vt .= 0 kfn.lif_vt .= 0
kfn.lif_wRecChange .= 0 kfn.lif_wRecChange .= 0
kfn.lif_epsilonRec .= 0 kfn.lif_epsilonRec .= 0
@@ -109,6 +111,7 @@ function (kfn::kfn_1)(input::AbstractArray)
reshape(kfn.lif_zt, (size(input, 1), :, 1, size(input, 3))), reshape(kfn.lif_zt, (size(input, 1), :, 1, size(input, 3))),
reshape(kfn.alif_zt, (size(input, 1), :, 1, size(input, 3))), dims=2) reshape(kfn.alif_zt, (size(input, 1), :, 1, size(input, 3))), dims=2)
kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3))) kfn.zit .= reshape(_zit, (size(input, 1), :, size(input, 3)))
kfn.zit_cumulative .+= kfn.zit
# project 3D kfn zit into 4D on zit # project 3D kfn zit into 4D on zit
i1, i2, i3, i4 = size(kfn.on_zit) i1, i2, i3, i4 = size(kfn.on_zit)

View File

@@ -23,6 +23,7 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
learningStage::Union{AbstractArray, Nothing} = nothing # 0 inference, 1 start, 2 during, 3 end learning learningStage::Union{AbstractArray, Nothing} = nothing # 0 inference, 1 start, 2 during, 3 end learning
inputSize::Union{AbstractArray, Nothing} = nothing inputSize::Union{AbstractArray, Nothing} = nothing
zit::Union{AbstractArray, Nothing} = nothing # 3D activation matrix zit::Union{AbstractArray, Nothing} = nothing # 3D activation matrix
zit_cumulative::Union{AbstractArray, Nothing} = nothing
modelError::Union{AbstractArray, Nothing} = nothing # store RSNN error modelError::Union{AbstractArray, Nothing} = nothing # store RSNN error
outputError::Union{AbstractArray, Nothing} = nothing # store output neurons error outputError::Union{AbstractArray, Nothing} = nothing # store output neurons error
@@ -175,6 +176,7 @@ function kfn_1(params::Dict; device=cpu)
# activation matrix # activation matrix
kfn.zit = zeros(row, col, batch) |> device kfn.zit = zeros(row, col, batch) |> device
kfn.zit_cumulative = (similar(kfn.zit) .= 0) |> device
kfn.modelError = zeros(1) |> device kfn.modelError = zeros(1) |> device
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #