module learn

export learn!, compute_paramsChange!

using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates
using GeneralUtils
using ..type, ..snnUtil

#------------------------------------------------------------------------------------------------100

"""
    compute_paramsChange!(kfn::kfn_1, modelError, outputError)

Accumulate weight changes for every neuron population of `kfn`.

Forwards the matching `kfn` fields to the CuArray methods of `lifComputeParamsChange!`,
`alifComputeParamsChange!` and `onComputeParamsChange!`:

- LIF and ALIF recurrent weights use `modelError` as their learning signal.
- Output-neuron weights use `outputError`.

Changes are added into `kfn.lif_wRecChange`, `kfn.alif_wRecChange` and `kfn.on_wOutChange`.
The weights themselves are only modified later by `learn!`.
Returns the value returned by `onComputeParamsChange!` (`kfn.on_wOutChange`).
"""
function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
    lifComputeParamsChange!(kfn.lif_phi, kfn.lif_epsilonRec, kfn.lif_eta, kfn.lif_eRec,
        kfn.lif_wRec, kfn.lif_wRecChange, kfn.on_wOut, kfn.lif_arrayProjection4d,
        kfn.lif_error, modelError)

    alifComputeParamsChange!(kfn.alif_phi, kfn.alif_epsilonRec, kfn.alif_eta,
        kfn.alif_eRec, kfn.alif_wRec, kfn.alif_wRecChange, kfn.on_wOut,
        kfn.alif_arrayProjection4d, kfn.alif_error, modelError, kfn.alif_beta)

    return onComputeParamsChange!(kfn.on_phi, kfn.on_epsilonRec, kfn.on_eta, kfn.on_eRec,
        kfn.on_wOut, kfn.on_wOutChange, outputError)
end

"""
    lifComputeParamsChange!(phi, epsilonRec, eta, eRec, wRec, wRecChange, wOut,
                            arrayProjection4d, nError, modelError)

CuArray method: accumulate the recurrent-weight change of the LIF population.

Side effects, in order:

1. Writes the learning signal into `nError`.
2. Writes the eligibility trace `phi .* epsilonRec` into `eRec`.
3. Adds `-eta .* nError .* eRec` into `wRecChange`, restricted to synapses where
   `wRec != 0`.

Returns `wRecChange`.
"""
function lifComputeParamsChange!(
        phi::CuArray, epsilonRec::CuArray, eta::CuArray, eRec::CuArray, wRec::CuArray,
        wRecChange::CuArray, wOut::CuArray, arrayProjection4d::CuArray, nError::CuArray,
        modelError::CuArray)
    # Sum wOut over dim 3, then broadcast into this population's 4-d layout.
    # NOTE(review): the AbstractArray method takes neuron i's own entry
    # (`view(wOutSum, :, :, j)[i]`). Here the whole summed map is broadcast across
    # dim 3 instead — confirm the two paths agree.
    wOutSum = sum(wOut, dims=3) .* arrayProjection4d

    # nError a.k.a. learning signal (dopamine concept): the neuron receives the summed
    # error signal (modelError), weighted by wOutSum.
    nError .= (modelError .* arrayProjection4d) .* wOutSum

    eRec .= phi .* epsilonRec

    # GeneralUtils.isNotEqual.(wRec, 0) is a subscription filter: synapses whose weight is
    # 0 (not subscribed) receive no change.
    wRecChange .+= ((-1 .* eta) .* nError .* eRec) .* GeneralUtils.isNotEqual.(wRec, 0)
    return wRecChange
end

"""
    alifComputeParamsChange!(phi, epsilonRec, eta, eRec, wRec, wRecChange, wOut,
                             arrayProjection4d, nError, modelError, beta)

CuArray method: accumulate the recurrent-weight change of the ALIF population.

Same as the LIF CuArray method, except the eligibility trace is
`phi .* epsilonRec .+ phi .* epsilonRec .* beta`. Returns `wRecChange`.
"""
function alifComputeParamsChange!(
        phi::CuArray, epsilonRec::CuArray, eta::CuArray, eRec::CuArray, wRec::CuArray,
        wRecChange::CuArray, wOut::CuArray, arrayProjection4d::CuArray, nError::CuArray,
        modelError::CuArray, beta::CuArray)
    # Sum wOut over dim 3, then broadcast into this population's 4-d layout.
    wOutSum = sum(wOut, dims=3) .* arrayProjection4d

    # nError a.k.a. learning signal (dopamine concept): the neuron receives the summed
    # error signal (modelError), weighted by wOutSum.
    nError .= (modelError .* arrayProjection4d) .* wOutSum

    # Voltage part + adaptation part of the eligibility trace.
    # NOTE(review): the AbstractArray method uses a separate `epsilonRecA` for the
    # adaptation part. This method reuses `epsilonRec` — confirm this is intended.
    eRec .= (phi .* epsilonRec) .+ (phi .* epsilonRec .* beta)

    # Subscription filter: synapses with wRec == 0 receive no change.
    wRecChange .+= ((-1 .* eta) .* nError .* eRec) .* GeneralUtils.isNotEqual.(wRec, 0)
    return wRecChange
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, eRec, wOut, wOutChange, outputError)

CuArray method: accumulate the weight change of the output neurons.

Each output neuron's learning signal is its own entry of `outputError`. The error is
reshaped to `(1, 1, :, size(epsilonRec, 4))` so it broadcasts over the first two
(synapse) dims. Note that `eRec` receives the error-weighted trace
`phi .* epsilonRec .* outputError`. Returns `wOutChange`.
"""
function onComputeParamsChange!(phi::CuArray, epsilonRec::CuArray, eta::CuArray,
        eRec::CuArray, wOut::CuArray, wOutChange::CuArray,
        outputError::CuArray) # outputError is output neuron's error
    eRec .= (phi .* epsilonRec) .* reshape(outputError, (1, 1, :, size(epsilonRec, 4)))

    # Subscription filter: synapses with wOut == 0 receive no change.
    wOutChange .+= ((-1 .* eta) .* eRec) .* GeneralUtils.isNotEqual.(wOut, 0)
    return wOutChange
end

"""
    lifComputeParamsChange!(phi, epsilonRec, eta, wRec, wRecChange, wOut, modelError)

Loop-based AbstractArray method of the LIF weight-change accumulation.

For every neuron `i` (dim 3) of every batch entry `j` (dim 4), adds to
`wRecChange[:, :, i, j]`:

    -eta[i,j] * phi[i,j] * epsilonRec[:, :, i, j] * modelError[1, j] * wOutSum[i, j]

`wOutSum[i, j]` is neuron i's total wOut weight. `wRec` is accepted but not used.
"""
function lifComputeParamsChange!(
        phi::AbstractArray, epsilonRec::AbstractArray, eta::AbstractArray,
        wRec::AbstractArray, wRecChange::AbstractArray, wOut::AbstractArray,
        modelError::AbstractArray)
    d1, _, d3, d4 = size(epsilonRec)

    # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3 # along the neuron axis (dim 3) of every batch (dim 4)
        # nError a.k.a. learning signal: summed model error (dopamine concept) times
        # this neuron's total wOut weight
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]

        # how much error of this neuron's 1-spike causes each output neuron's error
        view(wRecChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .*
            ((view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .* learningSignal)
    end
end

"""
    alifComputeParamsChange!(phi, epsilonRec, epsilonRecA, eta, wRec, wRecChange, beta,
                             wOut, modelError)

Loop-based AbstractArray method of the ALIF weight-change accumulation.

The eligibility trace has two parts:

- a voltage part: `phi * epsilonRec`;
- an adaptation part: `phi * beta * epsilonRecA`.

Their sum is multiplied by `-eta` and by the learning signal
`modelError[1, j] * wOutSum[i, j]`. `wRec` is accepted but not used.
"""
function alifComputeParamsChange!(
        phi::AbstractArray, epsilonRec::AbstractArray, epsilonRecA::AbstractArray,
        eta::AbstractArray, wRec::AbstractArray, wRecChange::AbstractArray,
        beta::AbstractArray, wOut::AbstractArray, modelError::AbstractArray)
    d1, _, d3, d4 = size(epsilonRec)

    # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3 # along the neuron axis (dim 3) of every batch (dim 4)
        phiVal = view(phi, :, :, i, j)[1]

        # nError a.k.a. learning signal: summed model error times this neuron's total
        # wOut weight
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]

        view(wRecChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .*
            (
                # eRec_v
                (phiVal .* view(epsilonRec, :, :, i, j)) .+
                # eRec_a
                ((phiVal * view(beta, :, :, i, j)[1]) .* view(epsilonRecA, :, :, i, j))
            ) .* learningSignal
    end
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, wOutChange, outputError)

Loop-based AbstractArray method of the output-neuron weight-change accumulation.

The learning signal for neuron `i` of batch entry `j` is `outputError[i, j]`, i.e. the
neuron's own error.
"""
function onComputeParamsChange!(phi::AbstractArray, epsilonRec::AbstractArray,
        eta::AbstractArray, wOutChange::AbstractArray, outputError::AbstractArray)
    _, _, d3, d4 = size(epsilonRec)

    for j in 1:d4, i in 1:d3 # along the neuron axis (dim 3) of every batch (dim 4)
        view(wOutChange, :, :, i, j) .+= (-1 * view(eta, :, :, i, j)[1]) .*
            ((view(phi, :, :, i, j)[1] .* view(epsilonRec, :, :, i, j)) .*
             view(outputError, :, j)[i])
    end
end

"""
    learn!(kfn::kfn_1)

Apply the accumulated weight changes of all populations and close the learning session.

Steps, in order:

1. Merges the accumulated `*Change` arrays into `lif_wRec`, `alif_wRec` and `on_wOut`.
2. Scales `on_wOut` by `0.0001`.
3. Resets `kfn.learningStage` from `[3]` to `[0]`.

NOTE(review): nothing in this module clears the `*Change` accumulators after they are
applied. Confirm they are reset elsewhere; otherwise the same changes are re-applied on
the next call.
"""
function learn!(kfn::kfn_1)
    lifLearn!(kfn.lif_wRec, kfn.lif_wRecChange, kfn.lif_arrayProjection4d)
    alifLearn!(kfn.alif_wRec, kfn.alif_wRecChange, kfn.alif_arrayProjection4d)
    onLearn!(kfn.on_wOut, kfn.on_wOutChange, kfn.on_arrayProjection4d)

    # wOut decay
    # NOTE(review): multiplying by 0.0001 keeps only 0.01% of every output weight. That
    # includes the update onLearn! just applied. If a small decay was intended, this is
    # probably meant to be `kfn.on_wOut .*= (1 - 0.0001)` — confirm.
    kfn.on_wOut .*= 0.0001

    # wrap up learning session
    if kfn.learningStage == [3]
        kfn.learningStage = [0]
    end
end

# Add the mean of `change` over the batch axis (dim 4) to `w`. The mean is broadcast back
# across the 4-d layout via `arrayProjection4d`. Returns `w`.
function _mergeBatchMean!(w, change, arrayProjection4d)
    # sum over dims=4 only: reducing over every dim would add a single scalar to every
    # synapse (including unsubscribed ones) instead of each synapse's own change
    w .+= (sum(change, dims=4) ./ size(w, 4)) .* arrayProjection4d
    return w
end

"""
    lifLearn!(wRec, wRecChange, arrayProjection4d)

Merge the batch-averaged LIF weight change into `wRec`. Returns `wRec`.
"""
function lifLearn!(wRec, wRecChange, arrayProjection4d)
    #TODO synaptic strength
    #TODO neuroplasticity
    return _mergeBatchMean!(wRec, wRecChange, arrayProjection4d)
end

"""
    alifLearn!(wRec, wRecChange, arrayProjection4d)

Merge the batch-averaged ALIF weight change into `wRec`. Returns `wRec`.
"""
function alifLearn!(wRec, wRecChange, arrayProjection4d)
    #TODO synaptic strength
    #TODO neuroplasticity
    return _mergeBatchMean!(wRec, wRecChange, arrayProjection4d)
end

"""
    onLearn!(wOut, wOutChange, arrayProjection4d)

Merge the batch-averaged output-weight change into `wOut`. Returns `wOut`.
"""
function onLearn!(wOut, wOutChange, arrayProjection4d)
    #TODO synaptic strength
    #TODO neuroplasticity
    return _mergeBatchMean!(wOut, wOutChange, arrayProjection4d)
end

#TODO voltage regulator
#TODO frequency regulator

end # module