From 97dd3b65c4562cbd69aa83f5a4846c32b5790258 Mon Sep 17 00:00:00 2001 From: ton Date: Sat, 5 Aug 2023 19:17:28 +0700 Subject: [PATCH] minor fix --- src/learn.jl | 67 ++++++++++++++++++++++++++-------------------------- src/type.jl | 6 ++--- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/learn.jl b/src/learn.jl index a5108c8..f54399d 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -46,30 +46,7 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError) kfn.on_wOut, kfn.on_wOutChange, outputError) - error("DEBUG -> kfn compute_paramsChange! $(Dates.now())") -end - -function lifComputeParamsChange!( phi::CuArray, - epsilonRec::CuArray, - eta::CuArray, - eRec::CuArray, - wRec::CuArray, - wRecChange::CuArray, - wOut::CuArray, - arrayProjection4d::CuArray, - nError::CuArray, - modelError::CuArray) - - wOutSum = sum(wOut, dims=3) .* arrayProjection4d - - # nError a.k.a. learning signal use dopamine concept, - # this neuron receive summed error signal (modelError) - nError .= (modelError .* arrayProjection4d) .* wOutSum - eRec .= phi .* epsilonRec - - # GeneralUtils.isNotEqual(wRec, 0) is a subscribe filter use to filter out non-subscribed wRecChange - wRecChange .+= ((-1 .* eta) .* nError .* eRec) .* GeneralUtils.isNotEqual.(wRec, 0) - # error("DEBUG -> lifComputeParamsChange! $(Dates.now())") + # error("DEBUG -> kfn compute_paramsChange! $(Dates.now())") end function lifComputeParamsChange!( phi::AbstractArray, @@ -101,15 +78,36 @@ function lifComputeParamsChange!( phi::AbstractArray, end end -function alifComputeParamsChange!( phi, - epsilonRec, - epsilonRecA, - eta, - wRec, - wRecChange, - beta, - wOut, - modelError) +function lifComputeParamsChange!( phi::CuArray, + epsilonRec::CuArray, + eta::CuArray, + eRec::CuArray, + wRec::CuArray, + wRecChange::CuArray, + wOut::CuArray, + arrayProjection4d::CuArray, + nError::CuArray, + modelError::CuArray) + wOutSum = sum(wOut, dims=3) .* arrayProjection4d + + # nError a.k.a. learning signal use dopamine concept, + # this neuron receive summed error signal (modelError) + nError .= (modelError .* arrayProjection4d) .* wOutSum + eRec .= phi .* epsilonRec + # GeneralUtils.isNotEqual(wRec, 0) is a subscribe filter use to filter out non-subscribed wRecChange + wRecChange .+= ((-1 .* eta) .* nError .* eRec) .* GeneralUtils.isNotEqual.(wRec, 0) + # error("DEBUG -> lifComputeParamsChange! $(Dates.now())") +end + +function alifComputeParamsChange!( phi::AbstractArray, + epsilonRec::AbstractArray, + epsilonRecA::AbstractArray, + eta::AbstractArray, + wRec::AbstractArray, + wRecChange::AbstractArray, + beta::AbstractArray, + wOut::AbstractArray, + modelError::AbstractArray) d1, d2, d3, d4 = size(epsilonRec) # Bₖⱼ in paper, sum() to get each neuron's total wOut weight @@ -202,10 +200,11 @@ function onComputeParamsChange!(phi::CuArray, end function learn!(kfn::kfn_1) + println(">>> lif_wRecChange ", sum(kfn.lif_wRecChange[:, :, 1, 1])) #WORKING lif learn lifLearn!(kfn.lif_wRec, kfn.lif_wRecChange) - + error("DEBUG -> kfn learn! $(Dates.now())") #TODO alif learn diff --git a/src/type.jl b/src/type.jl index 049cc17..c56e313 100644 --- a/src/type.jl +++ b/src/type.jl @@ -204,7 +204,7 @@ function kfn_1(params::Dict; device=cpu) kfn.lif_phi = (similar(kfn.lif_wRec) .= 0) |> device kfn.lif_epsilonRec = (similar(kfn.lif_wRec) .= 0) |> device kfn.lif_eRec = (similar(kfn.lif_wRec) .= 0) |> device - kfn.lif_eta = (similar(kfn.lif_wRec) .= 0) |> device + kfn.lif_eta = (similar(kfn.lif_wRec) .= 0.001) |> device kfn.lif_gammaPd = (similar(kfn.lif_wRec) .= 0.3) |> device kfn.lif_wRecChange = (similar(kfn.lif_wRec) .= 0) |> device kfn.lif_error = (similar(kfn.lif_wRec) .= 0) |> device @@ -250,7 +250,7 @@ function kfn_1(params::Dict; device=cpu) kfn.alif_phi = (similar(kfn.alif_wRec) .= 0) |> device kfn.alif_epsilonRec = (similar(kfn.alif_wRec) .= 0) |> device kfn.alif_eRec = (similar(kfn.alif_wRec) .= 0) |> device - kfn.alif_eta = (similar(kfn.alif_wRec) .= 0) |> device + kfn.alif_eta = (similar(kfn.alif_wRec) .= 0.001) |> device kfn.alif_gammaPd = (similar(kfn.alif_wRec) .= 0.3) |> device kfn.alif_wRecChange = (similar(kfn.alif_wRec) .= 0) |> device kfn.alif_error = (similar(kfn.alif_wRec) .= 0) |> device @@ -309,7 +309,7 @@ function kfn_1(params::Dict; device=cpu) kfn.on_phi = (similar(kfn.on_wOut) .= 0) |> device kfn.on_epsilonRec = (similar(kfn.on_wOut) .= 0) |> device kfn.on_eRec = (similar(kfn.on_wOut) .= 0) |> device - kfn.on_eta = (similar(kfn.on_wOut) .= 0) |> device + kfn.on_eta = (similar(kfn.on_wOut) .= 0.001) |> device kfn.on_gammaPd = (similar(kfn.on_wOut) .= 0.3) |> device kfn.on_wOutChange = (similar(kfn.on_wOut) .= 0) |> device kfn.on_error = (similar(kfn.on_wOut) .= 0) |> device