diff --git a/src/forward.jl b/src/forward.jl index d5f9498..1684606 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -15,16 +15,19 @@ function (kfn::kfn_1)(input::AbstractArray) kfn.timeStep .+= 1 - #TODO time step forward + # what to do at the start of learning round if view(kfn.learningStage, 1)[1] == 1 # reset learning params kfn.lif_vt .= 0 kfn.lif_wRecChange .= 0 kfn.lif_epsilonRec .= 0 + kfn.lif_firingCounter .= 0 kfn.alif_vt .= 0 kfn.alif_epsilonRec .= 0 + kfn.alif_epsilonRecA .= 0 kfn.alif_wRecChange .= 0 + kfn.alif_firingCounter .= 0 kfn.on_vt .= 0 kfn.on_epsilonRec .= 0 diff --git a/src/learn.jl b/src/learn.jl index f147b2a..d5213a9 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -16,7 +16,9 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError) kfn.lif_eRec, kfn.lif_wRec, kfn.lif_wRecChange, - kfn.on_wOut, + kfn.on_wOut, + kfn.lif_firingCounter, + kfn.lif_firingTargetFrequency, kfn.lif_arrayProjection4d, kfn.lif_error, modelError, @@ -30,7 +32,9 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError) kfn.alif_eRec, kfn.alif_wRec, kfn.alif_wRecChange, - kfn.on_wOut, + kfn.on_wOut, + kfn.alif_firingCounter, + kfn.alif_firingTargetFrequency, kfn.alif_arrayProjection4d, kfn.alif_error, modelError, @@ -59,6 +63,8 @@ function lifComputeParamsChange!( phi::CuArray, wRec::CuArray, wRecChange::CuArray, wOut::CuArray, + firingCounter::CuArray, + firingTargetFrequency::CuArray, arrayProjection4d::CuArray, nError::CuArray, modelError::CuArray, @@ -81,6 +87,8 @@ function lifComputeParamsChange!( phi::CuArray, eRec .= phi .* epsilonRec wRecChange .+= (-eta .* nError .* eRec) + #TODO frequency regulator + # reset epsilonRec epsilonRec .= 0 end @@ -92,6 +100,8 @@ function alifComputeParamsChange!( phi::CuArray, wRec::CuArray, wRecChange::CuArray, wOut::CuArray, + firingCounter::CuArray, + firingTargetFrequency::CuArray, arrayProjection4d::CuArray, nError::CuArray, modelError::CuArray, @@ -114,6 +124,8 @@ function alifComputeParamsChange!( phi::CuArray, eRec .= phi .* (epsilonRec .- (beta .* epsilonRecA)) # use eq. 25 wRecChange .+= (-eta .* nError .* eRec) + #TODO frequency regulator + # reset epsilonRec epsilonRec .= 0 epsilonRecA .= 0 diff --git a/src/type.jl b/src/type.jl index ec1a905..2bdac00 100644 --- a/src/type.jl +++ b/src/type.jl @@ -54,6 +54,7 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn lif_subscription::Union{AbstractArray, Nothing} = nothing lif_firingCounter::Union{AbstractArray, Nothing} = nothing + lif_firingTargetFrequency::Union{AbstractArray, Nothing} = nothing # pre-allocation array lif_arrayProjection4d::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d @@ -90,6 +91,7 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn alif_subscription::Union{AbstractArray, Nothing} = nothing alif_firingCounter::Union{AbstractArray, Nothing} = nothing + alif_firingTargetFrequency::Union{AbstractArray, Nothing} = nothing # pre-allocation array alif_arrayProjection4d::Union{AbstractArray, Nothing} = nothing # use to project 3d array to 4d @@ -215,6 +217,7 @@ function kfn_1(params::Dict; device=cpu) kfn.lif_subscription = (GeneralUtils.isNotEqual.(kfn.lif_wRec, 0)) |> device kfn.lif_firingCounter = (similar(kfn.lif_wRec) .= 0) |> device + kfn.lif_firingTargetFrequency = (similar(kfn.lif_wRec) .= 80) |> device kfn.lif_arrayProjection4d = (similar(kfn.lif_wRec) .= 1) |> device kfn.lif_recSignal = (similar(kfn.lif_wRec) .= 0) |> device @@ -262,6 +265,7 @@ function kfn_1(params::Dict; device=cpu) kfn.alif_subscription = (GeneralUtils.isNotEqual.(kfn.alif_wRec, 0)) |> device kfn.alif_firingCounter = (similar(kfn.alif_wRec) .= 0) |> device + kfn.alif_firingTargetFrequency = (similar(kfn.alif_wRec) .= 80) |> device kfn.alif_arrayProjection4d = (similar(kfn.alif_wRec) .= 1) |> device kfn.alif_recSignal = (similar(kfn.alif_wRec) .= 0) |> device