From 0da983f4936f460dce395feb3ab1f2b9a9f1d2a1 Mon Sep 17 00:00:00 2001 From: ton Date: Thu, 27 Jul 2023 10:00:20 +0700 Subject: [PATCH] learn() --- src/forward.jl | 52 +++++++++++++++++ src/learn.jl | 153 ++++++++++++++++++++++++++++++++----------------- src/type.jl | 7 ++- 3 files changed, 159 insertions(+), 53 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index 28c13d1..da2190c 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -16,6 +16,7 @@ function (kfn::kfn_1)(input::AbstractArray) #TODO time step forward if kfn.learningStage == [1] # reset learning params + kfn.learningStage = [2] end d1, d2, d3 = size(input) @@ -271,6 +272,57 @@ function onForward(kfn_zit, end end +# function onForward(kfn_zit, +# zit, +# wOut, +# vt0, +# vt1, +# vth, +# vRest, +# zt1, +# alpha, +# phi, +# epsilonRec, +# refractoryCounter, +# refractoryDuration, +# gammaPd, +# firingCounter) +# d1, d2, d3, d4 = size(wOut) +# zit .= reshape(kfn_zit, (d1, d2, 1, d4)) .* ones(size(wOut)...) # project zit into zit + +# for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch +# if view(refractoryCounter, :, :, i, j)[1] > 0 # neuron is inactive (in refractory period) +# view(refractoryCounter, :, :, i, j)[1] -= 1 +# view(zt1, :, :, i, j)[1] = 0 +# view(vt1, :, :, i, j)[1] = +# view(alpha, :, :, i, j)[1] * view(vt0, :, :, i, j)[1] +# view(phi, :, :, i, j)[1] = 0.0 +# view(epsilonRec, :, :, i, j) .= view(alpha, :, :, i, j)[1] .* +# view(epsilonRec, :, :, i, j) +# else # neuron is active +# view(vt1, :, :, i, j)[1] = +# (view(alpha, :, :, i, j)[1] * view(vt0,:, :, i, j)[1]) + +# sum(view(zit, :, :, i, j) .* view(wOut, :, :, i, j)) +# if view(vt1, :, :, i, j)[1] > view(vth, :, :, i, j)[1] +# view(zt1, :, :, i, j)[1] = 1 +# view(refractoryCounter, :, :, i, j)[1] = +# view(refractoryDuration, :, :, i, j)[1] +# view(firingCounter, :, :, i, j)[1] += 1 +# view(vt1, :, :, i, j)[1] = view(vRest, :, :, i, j)[1] +# else +# view(zt1, :, :, i, j)[1] = 0 +# end +# # there is a difference from alif formula +# view(phi, :, :, i, j)[1] = +# (view(gammaPd, :, :, i, j)[1] / view(vth, :, :, i, j)[1]) * +# max(0, 1 - ((view(vt1, :, :, i, j)[1] - view(vth, :, :, i, j)[1]) / +# view(vth, :, :, i, j)[1])) +# view(epsilonRec, :, :, i, j) .= +# (view(alpha, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) + +# view(zit, :, :, i, j) +# end +# end +# end diff --git a/src/learn.jl b/src/learn.jl index aa6f9f9..451eab6 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -9,7 +9,7 @@ using ..type, ..snnUtil #------------------------------------------------------------------------------------------------100 function compute_paramsChange!(kfn::kfn_1, modelError, outputError) - #WORKING + lifComputeParamsChange!(kfn.lif_phi, kfn.lif_epsilonRec, @@ -29,36 +29,14 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError) kfn.on_wOut, modelError) - - + onComputeParamsChange!(kfn.on_phi, + kfn.on_epsilonRec, + kfn.on_eta, + kfn.on_wOutChange, + outputError) error("debug end -> kfn compute_paramsChange! $(Dates.now())") - - # Threads.@threads for n in kfn.neuronsArray - # # for n in kfn.neuronsArray - # if typeof(n) <: computeNeuron - # wOut = Int64[] - # for oN in kfn.outputNeuronsArray - # wIndex = findall(isequal.(oN.subscriptionList, n.id)) - # if length(wIndex) != 0 - # push!(wOut, wIndex[1]) - # end - # end - - # if length(wOut) != 0 - # compute_wRecChange!(n, wOut, modelError) - # # compute_alphaChange!(n, modelError) - # compute_firingRateError!(n, kfn.kfnParams[:neuronFiringRateTarget], - # kfn.kfnParams[:totalComputeNeuron]) - # end - # end - # end - - # for oN in kfn.outputNeuronsArray - # compute_wRecChange!(oN, outputError[oN.id]) - # # compute_alphaChaZnge!(oN, outputError[oN.id]) - # end end function lifComputeParamsChange!( phi, @@ -77,15 +55,16 @@ function lifComputeParamsChange!( phi, # how much error of this neuron 1-spike causing each output neuron's error view(wRecChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .* - # eRec - ( (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .* - # nError a.k.a. learning signal - (view(modelError, :, j)[1] .* - # RSNN neuron's total wOut weight (neuron synaptic subscription .* wOutSum) - sum(GeneralUtils.isNotEqual.(view(wRec, :, :, i, j), 0) .* - view(wOutSum, :, :, j)) - ) - ) + # eRec + ( + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .* + # nError a.k.a. learning signal + ( + view(modelError, :, j)[1] * # dopamine concept, this neuron receive summed error signal + # RSNN neuron's total wOut weight (neuron synaptic subscription .* wOutSum) + view(wOutSum, :, :, j)[i] + ) + ) end end @@ -108,30 +87,100 @@ function alifComputeParamsChange!( phi, # how much error of this neuron 1-spike causing each output neuron's error view(wRecChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .* - # eRec - ( - # eRec_v - (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .+ - # eRec_a - ((view(phi, :, :, i, j)[1] * view(beta, :, :, i, j)[1]) .* - view(epsilonRecA, :, :, i, j)) - ) .* - # nError a.k.a. learning signal - ( - view(modelError, :, j)[1] .* - # RSNN neuron's total wOut weight (neuron synaptic subscription .* wOutSum) - sum(GeneralUtils.isNotEqual.(view(wRec, :, :, i, j), 0) .* - view(wOutSum, :, :, j)) - ) + # eRec + ( + # eRec_v + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .+ + # eRec_a + ((view(phi, :, :, i, j)[1] * view(beta, :, :, i, j)[1]) .* + view(epsilonRecA, :, :, i, j)) + ) .* + # nError a.k.a. learning signal + ( + view(modelError, :, j)[1] * + # RSNN neuron's total wOut weight (neuron synaptic subscription .* wOutSum) + view(wOutSum, :, :, j)[i] + # sum(GeneralUtils.isNotEqual.(view(wRec, :, :, i, j), 0) .* + # view(wOutSum, :, :, j)) + ) end end +function onComputeParamsChange!(phi, + epsilonRec, + eta, + wOutChange, + outputError) + d1, d2, d3, d4 = size(epsilonRec) + + for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch + # how much error of this neuron 1-spike causing each output neuron's error + + view(wOutChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .* + # eRec + ( + (view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .* + # nError a.k.a. learning signal, output neuron receives error of its own answer - correct answer. + view(outputError, :, j)[i] + ) + end +end + +# function onComputeParamsChange!(wOut, +# epsilonRec, +# eta, +# wOutChange, +# bChange, +# outputError) +# d1, d2, d3, d4 = size(epsilonRec) +# println(">>> epsilon ", size(epsilonRec)) +# println(">>> outputError ", size(outputError)) + + +# # Bₖⱼ in paper, sum() to get each neuron's total wOut weight + +# for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch +# # how much error of this neuron 1-spike causing each output neuron's error + +# view(wOutChange, :, :, i, j) .+= +# (-1 * view(eta, :, :, i, j)[1] * view(outputError, :, j)[i]) .* +# view(epsilonRec, :, :, i, j) +# end +# #TODO add b +# error(">>> DEBUG -> onComputeParamsChange!") +# end +function learn!(kfn::kfn_1) + #WORKING lif learn + lifLearn!(kfn.lif_wRec, + kfn.lif_wRecChange) + + + #TODO alif learn + #TODO on learn + #TODO wOut decay + + # wrap up learning session + if kfn.learningStage == [3] + kfn.learningStage = [0] + end +end + +function lifLearn!(wRec, + wRecChange) + # merge learning weight + wRec .+= wRecChange + + #TODO synaptic strength + + #TODO neuroplasticity + +end diff --git a/src/type.jl b/src/type.jl index 622dc49..3133489 100644 --- a/src/type.jl +++ b/src/type.jl @@ -106,7 +106,10 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn on_eRec::Union{AbstractArray, Nothing} = nothing on_eta::Union{AbstractArray, Nothing} = nothing on_gammaPd::Union{AbstractArray, Nothing} = nothing + on_wOutChange::Union{AbstractArray, Nothing} = nothing + on_b::Union{AbstractArray, Nothing} = nothing + on_bChange::Union{AbstractArray, Nothing} = nothing on_firingCounter::Union{AbstractArray, Nothing} = nothing end @@ -219,7 +222,7 @@ function kfn_1(params::Dict) kfn.on_zt0 = zeros(1, 1, n, batch) kfn.on_zt1 = zeros(1, 1, n, batch) kfn.on_refractoryCounter = zeros(1, 1, n, batch) - kfn.on_refractoryDuration = ones(1, 1, n, batch) .* 1 + kfn.on_refractoryDuration = ones(1, 1, n, batch) .* 0 kfn.on_alpha = ones(1, 1, n, batch) .* (exp(-kfn.on_delta / kfn.on_tau_m)) kfn.on_phi = zeros(1, 1, n, batch) kfn.on_epsilonRec = zeros(row, col, n, batch) @@ -227,6 +230,8 @@ function kfn_1(params::Dict) kfn.on_eta = zeros(1, 1, n, batch) kfn.on_gammaPd = zeros(1, 1, n, batch) .* 0.3 kfn.on_wOutChange = zeros(row, col, n, batch) + kfn.on_b = randn(1, 1, n, batch) + kfn.on_bChange = randn(1, 1, n, batch) # subscription w = zeros(row, col, n)