module learn

export learn!, compute_paramsChange!

using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates
using GeneralUtils
using ..type, ..snnUtil

#------------------------------------------------------------------------------------------------100

"""
    compute_paramsChange!(kfn::kfn_1, modelError, outputError)

Accumulate one step of weight changes for every neuron population of `kfn`:
LIF recurrent weights into `kfn.lif_wRecChange`, ALIF recurrent weights into
`kfn.alif_wRecChange`, and output-neuron weights into `kfn.on_wOutChange`.
The weights themselves are not modified here; see [`learn!`](@ref).

`modelError` is reshaped to `(1, 1, 1, batch)` so one error value per batch item
broadcasts over all neurons. `outputError` is passed to the output-neuron update,
which reshapes it to `(1, 1, :, size(outputError, 2))` — it is presumably a
`(nOutput, batch)` matrix (TODO confirm against caller).
"""
function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
    modelError = reshape(modelError, (1, 1, 1, :))  # (1,1,1,batch)

    lifComputeParamsChange!(
        kfn.lif_phi,
        kfn.lif_epsilonRec,
        kfn.lif_eta,
        kfn.lif_eRec,
        kfn.lif_wRec,
        kfn.lif_wRecChange,
        kfn.on_wOut,
        kfn.lif_arrayProjection4d,
        kfn.lif_error,
        modelError,
        kfn.inputSize,
    )

    alifComputeParamsChange!(
        kfn.alif_phi,
        kfn.alif_epsilonRec,
        kfn.alif_eta,
        kfn.alif_eRec,
        kfn.alif_wRec,
        kfn.alif_wRecChange,
        kfn.on_wOut,
        kfn.alif_arrayProjection4d,
        kfn.alif_error,
        modelError,
        kfn.alif_epsilonRecA,
        kfn.alif_beta,
    )

    onComputeParamsChange!(
        kfn.on_phi,
        kfn.on_epsilonRec,
        kfn.on_eta,
        kfn.on_eRec,
        kfn.on_wOut,
        kfn.on_wOutChange,
        kfn.on_arrayProjection4d,
        kfn.on_error,
        outputError,
    )
end

"""
    lifComputeParamsChange!(phi, epsilonRec, eta, eRec, wRec, wRecChange, wOut,
                            arrayProjection4d, nError, modelError, inputSize)

GPU update for the LIF population. Writes the eligibility trace
`eRec = phi .* epsilonRec` and the learning signal `nError` into their buffers,
adds `-eta .* nError .* eRec` to `wRecChange`, then zeroes `epsilonRec`.

The learning signal is the (broadcast) `modelError` scaled by each LIF neuron's
total readout weight magnitude, `abs.(sum(wOut, dims=3))`. In the flattened
neuron axis the LIF neurons start right after the first `prod(inputSize)`
entries and span `size(wRec, 3)` entries.
"""
function lifComputeParamsChange!(
    phi::CuArray,
    epsilonRec::CuArray,
    eta::CuArray,
    eRec::CuArray,
    wRec::CuArray,
    wRecChange::CuArray,
    wOut::CuArray,
    arrayProjection4d::CuArray,
    nError::CuArray,
    modelError::CuArray,
    inputSize::CuArray,
)
    # Bₖⱼ in the paper: each neuron's summed wOut weight. Absolute value because
    # only the magnitude is needed.
    wOutSum_all = reshape(abs.(sum(wOut, dims=3)), (1, 1, :, size(wOut, 4)))  # (1,1,allNeuron,batch)

    # Keep only the LIF neurons' entries (they follow the input neurons).
    startIndex = prod(inputSize) + 1
    stopIndex = startIndex + size(wRec, 3) - 1
    wOutSum = wOutSum_all[:, :, startIndex:stopIndex, :]  # (1,1,n,batch)

    # nError a.k.a. learning signal (dopamine concept): every neuron receives the
    # summed model error, weighted by its readout strength.
    nError .= (modelError .* wOutSum) .* arrayProjection4d
    eRec .= phi .* epsilonRec
    wRecChange .+= (-eta .* nError .* eRec)

    # epsilonRec is consumed by this update; start the next accumulation from 0
    epsilonRec .= 0
end

"""
    alifComputeParamsChange!(phi, epsilonRec, eta, eRec, wRec, wRecChange, wOut,
                             arrayProjection4d, nError, modelError, epsilonRecA, beta)

GPU update for the ALIF population. Same as the LIF update, except the
eligibility trace includes the adaptation term (eq. 25):
`eRec = phi .* (epsilonRec .- beta .* epsilonRecA)`. Zeroes both `epsilonRec`
and `epsilonRecA` afterwards.

The ALIF neurons are taken to be the last `size(wRec, 3)` entries of the
flattened neuron axis.
"""
function alifComputeParamsChange!(
    phi::CuArray,
    epsilonRec::CuArray,
    eta::CuArray,
    eRec::CuArray,
    wRec::CuArray,
    wRecChange::CuArray,
    wOut::CuArray,
    arrayProjection4d::CuArray,
    nError::CuArray,
    modelError::CuArray,
    epsilonRecA::CuArray,
    beta::CuArray,
)
    # Bₖⱼ in the paper: each neuron's summed wOut weight. Absolute value because
    # only the magnitude is needed.
    wOutSum_all = reshape(abs.(sum(wOut, dims=3)), (1, 1, :, size(wOut, 4)))  # (1,1,allNeuron,batch)

    # Keep only the ALIF neurons' entries (the trailing block of neurons).
    nAlif = size(wRec, 3)
    wOutSum = wOutSum_all[:, :, (end - nAlif + 1):end, :]  # (1,1,n,batch)

    # nError a.k.a. learning signal (dopamine concept): every neuron receives the
    # summed model error, weighted by its readout strength.
    nError .= (modelError .* wOutSum) .* arrayProjection4d
    eRec .= phi .* (epsilonRec .- (beta .* epsilonRecA))  # eq. 25
    wRecChange .+= (-eta .* nError .* eRec)

    # both traces are consumed by this update; start the next accumulation from 0
    epsilonRec .= 0
    epsilonRecA .= 0
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, eRec, wOut, wOutChange,
                           arrayProjection4d, nError, outputError)

GPU update for the output neurons. Each output neuron's learning signal is its
own entry of `outputError` (reshaped to `(1, 1, :, batch)`), not the summed
model error. Adds `-eta .* nError .* eRec` to `wOutChange` and zeroes
`epsilonRec`. `wOut` is accepted but not read here.
"""
function onComputeParamsChange!(
    phi::CuArray,
    epsilonRec::CuArray,
    eta::CuArray,
    eRec::CuArray,
    wOut::CuArray,
    wOutChange::CuArray,
    arrayProjection4d::CuArray,
    nError::CuArray,
    outputError::CuArray,  # output neurons' error
)
    eRec .= phi .* epsilonRec
    nError .= reshape(outputError, (1, 1, :, size(outputError, 2))) .* arrayProjection4d
    wOutChange .+= (-eta .* nError .* eRec)

    # epsilonRec is consumed by this update; start the next accumulation from 0
    epsilonRec .= 0
end

"""
    lifComputeParamsChange!(phi, epsilonRec, eta, wRec, wRecChange, wOut, modelError)

Per-neuron (loop-based) LIF update for generic arrays. For neuron `i` of batch `j`
it adds `-eta * (phi .* epsilonRec) * (modelError[j] * wOutSum[i])` to
`wRecChange[:, :, i, j]`. `phi` and `eta` are read as one scalar per `(i, j)`.

NOTE(review): unlike the `CuArray` method, this indexes `wOutSum` by `i` without
offsetting past the input neurons and does not take `abs` — confirm which
convention is intended.
"""
function lifComputeParamsChange!(
    phi::AbstractArray,
    epsilonRec::AbstractArray,
    eta::AbstractArray,
    wRec::AbstractArray,
    wRecChange::AbstractArray,
    wOut::AbstractArray,
    modelError::AbstractArray,
)
    d1, d2, d3, d4 = size(epsilonRec)
    # Bₖⱼ in the paper: each neuron's summed wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3  # every neuron of every batch item
        phi_ij = view(phi, :, :, i, j)[1]
        eta_ij = view(eta, :, :, i, j)[1]
        # learning signal: summed model error (dopamine concept) times this
        # neuron's total readout weight
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]
        view(wRecChange, :, :, i, j) .+=
            (-1 * eta_ij) .* (phi_ij .* view(epsilonRec, :, :, i, j)) .* learningSignal
    end
end

"""
    alifComputeParamsChange!(phi, epsilonRec, epsilonRecA, eta, wRec, wRecChange,
                             beta, wOut, modelError)

Per-neuron (loop-based) ALIF update for generic arrays. Same as the LIF loop, but
the eligibility trace includes the adaptation term (eq. 25, matching the
`CuArray` method): `eRec = phi * (epsilonRec - beta * epsilonRecA)`.

NOTE(review): like the CPU LIF method, `wOutSum` is indexed by `i` with no
population offset and no `abs` — confirm against the `CuArray` method.
"""
function alifComputeParamsChange!(
    phi::AbstractArray,
    epsilonRec::AbstractArray,
    epsilonRecA::AbstractArray,
    eta::AbstractArray,
    wRec::AbstractArray,
    wRecChange::AbstractArray,
    beta::AbstractArray,
    wOut::AbstractArray,
    modelError::AbstractArray,
)
    d1, d2, d3, d4 = size(epsilonRec)
    # Bₖⱼ in the paper: each neuron's summed wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3  # every neuron of every batch item
        phi_ij = view(phi, :, :, i, j)[1]
        eta_ij = view(eta, :, :, i, j)[1]
        beta_ij = view(beta, :, :, i, j)[1]
        # eq. 25: eRec = phi * (eRec_v - beta * eRec_a)
        eRec = phi_ij .* (view(epsilonRec, :, :, i, j) .- beta_ij .* view(epsilonRecA, :, :, i, j))
        # learning signal: summed model error times this neuron's total readout weight
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]
        view(wRecChange, :, :, i, j) .+= (-1 * eta_ij) .* eRec .* learningSignal
    end
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, wOutChange, outputError)

Per-neuron (loop-based) output-neuron update for generic arrays. Each output
neuron `i` of batch `j` uses its own error `outputError[i, j]` as its learning
signal.
"""
function onComputeParamsChange!(
    phi::AbstractArray,
    epsilonRec::AbstractArray,
    eta::AbstractArray,
    wOutChange::AbstractArray,
    outputError::AbstractArray,
)
    d1, d2, d3, d4 = size(epsilonRec)

    for j in 1:d4, i in 1:d3  # every neuron of every batch item
        phi_ij = view(phi, :, :, i, j)[1]
        eta_ij = view(eta, :, :, i, j)[1]
        # output neuron's own error (its answer minus the correct answer)
        learningSignal = view(outputError, :, j)[i]
        view(wOutChange, :, :, i, j) .+=
            (-1 * eta_ij) .* (phi_ij .* view(epsilonRec, :, :, i, j)) .* learningSignal
    end
end

"""
    learn!(kfn::kfn_1)

Apply the accumulated weight changes of every population to its weights, then,
if `kfn.learningStage == [3]`, reset it to `[0]` to close the learning session.

NOTE(review): the `*Change` accumulators are not zeroed here; presumably they are
reset elsewhere — confirm, otherwise changes are re-applied on every call.
"""
function learn!(kfn::kfn_1)
    lifLearn!(kfn.lif_wRec, kfn.lif_wRecChange, kfn.lif_arrayProjection4d)
    alifLearn!(kfn.alif_wRec, kfn.alif_wRecChange, kfn.alif_arrayProjection4d)
    onLearn!(kfn.on_wOut, kfn.on_wOutChange, kfn.on_arrayProjection4d)

    # wrap up learning session
    if kfn.learningStage == [3]
        kfn.learningStage = [0]
    end
end

"""
    lifLearn!(wRec, wRecChange, arrayProjection4d)

Add the batch-mean of `wRecChange` (summed over dim 4, divided by
`size(wRec, 4)`) to every batch slice of `wRec`, masked/projected by
`arrayProjection4d`.
"""
function lifLearn!(wRec, wRecChange, arrayProjection4d)
    wRec .+= (sum(wRecChange, dims=4) ./ size(wRec, 4)) .* arrayProjection4d
    #TODO synaptic strength
    #TODO neuroplasticity
end

"""
    alifLearn!(wRec, wRecChange, arrayProjection4d)

Add the batch-mean of `wRecChange` to `wRec`, exactly like `lifLearn!`.
Previously this summed over *all* dimensions, which collapsed the change to a
single scalar applied identically to every weight.
"""
function alifLearn!(wRec, wRecChange, arrayProjection4d)
    wRec .+= (sum(wRecChange, dims=4) ./ size(wRec, 4)) .* arrayProjection4d
    #TODO synaptic strength
    #TODO neuroplasticity
end

"""
    onLearn!(wOut, wOutChange, arrayProjection4d; decay=0.001)

Add the batch-mean of `wOutChange` to `wOut` (same averaging as `lifLearn!`),
then shrink `wOut` by the fraction `decay` to help convergence.
"""
function onLearn!(wOut, wOutChange, arrayProjection4d; decay=0.001)
    wOut .+= (sum(wOutChange, dims=4) ./ size(wOut, 4)) .* arrayProjection4d
    # adaptive wOut (weight decay) to help convergence
    wOut .-= decay .* wOut
    #TODO synaptic strength
    #TODO neuroplasticity
end

#TODO voltage regulator
#TODO frequency regulator

end # module