module learn

export learn!, compute_paramsChange!

using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates
using GeneralUtils
using ..type, ..snnUtil

#------------------------------------------------------------------------------------------------100

"""
    compute_paramsChange!(kfn::kfn_1, modelError::CuArray, outputError::CuArray, label)

Accumulate one step of parameter changes for `kfn` in place.

Adds to the recurrent weight changes of the LIF population (`kfn.lif_wRecChange`) and of
the ALIF population (`kfn.alif_wRecChange`), driven by `modelError`, and to the output
weight changes (`kfn.on_wOutChange`), driven by `outputError`. `label` selects which output
neuron's weights in `kfn.on_wOut` act as feedback weights for the recurrent learning
signal. The eligibility accumulators `kfn.lif_epsilonRec`, `kfn.alif_epsilonRec`,
`kfn.alif_epsilonRecA` and `kfn.on_epsilonRec` are zeroed afterwards.
"""
function compute_paramsChange!(kfn::kfn_1, modelError::CuArray, outputError::CuArray, label)
    lifComputeParamsChange!(kfn.timeStep,
                            kfn.lif_phi,
                            kfn.lif_epsilonRec,
                            kfn.lif_eta,
                            kfn.lif_eRec,
                            kfn.lif_wRec,
                            kfn.lif_exInType,
                            kfn.lif_wRecChange,
                            kfn.on_wOut,
                            kfn.lif_firingCounter,
                            kfn.lif_firingTargetFrequency,
                            kfn.lif_arrayProjection4d,
                            kfn.lif_error,
                            modelError,
                            outputError,
                            kfn.inputSize,
                            kfn.bk,
                            label,
                            )

    alifComputeParamsChange!(kfn.timeStep,
                             kfn.alif_phi,
                             kfn.alif_epsilonRec,
                             kfn.alif_eta,
                             kfn.alif_eRec,
                             kfn.alif_wRec,
                             kfn.alif_exInType,
                             kfn.alif_wRecChange,
                             kfn.on_wOut,
                             kfn.alif_firingCounter,
                             kfn.alif_firingTargetFrequency,
                             kfn.alif_arrayProjection4d,
                             kfn.alif_error,
                             modelError,
                             outputError,
                             kfn.inputSize,
                             kfn.bk,
                             label,
                             kfn.alif_epsilonRecA,
                             kfn.alif_beta,
                             )

    onComputeParamsChange!(kfn.on_phi,
                           kfn.on_epsilonRec,
                           kfn.on_eta,
                           kfn.on_eRec,
                           kfn.on_wOutChange,
                           kfn.on_arrayProjection4d,
                           kfn.on_error,
                           kfn.on_synapticActivityCounter,
                           outputError,
                           )
    # error("DEBUG -> kfn compute_paramsChange! $(Dates.now())")
end

# Shared body of the GPU lif/alifComputeParamsChange! methods. Adds two terms to
# `wRecChange` in place and returns it:
#   1. eta .* learningSignal .* eRec, where learningSignal = feedback .* modelError and
#      `feedback` is each neuron's weight to output neuron `label + 1` in `wOut`, with the
#      random weight from `bk` substituted wherever that output weight is exactly zero;
#   2. a firing-rate regulariser:
#      -eta .* eRec .* (firingCounter - firingTargetFrequency*timeStep) ./ timeStep.
function _accumulateWRecChange!(wRecChange, timeStep, eta, eRec, wRec, wOut, firingCounter,
                                firingTargetFrequency, modelError, inputSize, bk, label)
    # The (row, col) grid of wRec holds input, lif and alif neurons together. Map the
    # linear positions right after the input neurons onto grid columns.
    startIndex = prod(inputSize) + 1
    stopIndex = startIndex + size(wRec, 3) - 1
    startCol = CartesianIndices(wRec)[startIndex][2]
    stopCol = CartesianIndices(wRec)[stopIndex][2]

    # RSNN neurons wired directly to the output neuron of the correct answer use that
    # weight as feedback; all others fall back to their random feedback weight in `bk`.
    # `label + 1` because Julia indexing is 1-based.
    onW = @view(wOut[:, startCol:stopCol, sum(label + 1), 1])
    randomBk = @view(bk[:, startCol:stopCol, 1])
    feedback = onW .+ iszero.(onW) .* randomBk

    # learning signal of each neuron, laid out along axis 3 to broadcast against eRec
    learningSignal = reshape(feedback .* modelError, (1, 1, :, 1))
    wRecChange .+= (eta .* learningSignal .* eRec)

    # frequency regulator: push the firing rate toward firingTargetFrequency
    targetFiringCount = firingTargetFrequency .* timeStep
    freqError = (firingCounter .- targetFiringCount) ./ timeStep
    wRecChange .+= -1 .* freqError .* eta .* eRec
    return wRecChange
end

"""
    lifComputeParamsChange!(timeStep, phi, epsilonRec, eta, eRec, wRec, exInType, wRecChange,
                            wOut, firingCounter, firingTargetFrequency, arrayProjection4d,
                            nError, modelError, outputError, inputSize, bk, label)

GPU accumulation of the LIF recurrent weight change.

Sets `eRec .= phi .* epsilonRec` and adds the learning-signal term and the firing-rate
regulariser to `wRecChange`. It then zeroes `epsilonRec` and returns it.
`exInType`, `arrayProjection4d`, `nError` and `outputError` are not used.

NOTE(review): `nError` (called with `kfn.lif_error`) is never filled; the learning signal
is a local value. Confirm no caller expects it to be populated.
"""
function lifComputeParamsChange!(timeStep::CuArray,
                                 phi::CuArray,
                                 epsilonRec::CuArray,
                                 eta::CuArray,
                                 eRec::CuArray,
                                 wRec::CuArray,
                                 exInType::CuArray,
                                 wRecChange::CuArray,
                                 wOut::CuArray,
                                 firingCounter::CuArray,
                                 firingTargetFrequency::CuArray,
                                 arrayProjection4d::CuArray,
                                 nError::CuArray,
                                 modelError::CuArray,
                                 outputError::CuArray,
                                 inputSize::CuArray,
                                 bk::CuArray,
                                 label,
                                 )
    eRec .= phi .* epsilonRec

    _accumulateWRecChange!(wRecChange, timeStep, eta, eRec, wRec, wOut, firingCounter,
                           firingTargetFrequency, modelError, inputSize, bk, label)

    # reset epsilonRec
    epsilonRec .= 0
end

"""
    alifComputeParamsChange!(timeStep, phi, epsilonRec, eta, eRec, wRec, exInType,
                             wRecChange, wOut, firingCounter, firingTargetFrequency,
                             arrayProjection4d, nError, modelError, outputError, inputSize,
                             bk, label, epsilonRecA, beta)

GPU accumulation of the ALIF recurrent weight change.

Sets `eRec .= phi .* (epsilonRec .- beta .* epsilonRecA)` (eq. 25) and adds the
learning-signal term and the firing-rate regulariser to `wRecChange`. It then zeroes
`epsilonRec` and `epsilonRecA` and returns `epsilonRecA`.
`exInType`, `arrayProjection4d`, `nError` and `outputError` are not used.

NOTE(review): the start column is derived from `prod(inputSize) + 1` exactly as in the LIF
method, so ALIF neurons appear to read the same leading `wOut`/`bk` columns as the LIF
neurons. Confirm whether an offset for the LIF population is intended. `nError` is never
filled (see `lifComputeParamsChange!`).
"""
function alifComputeParamsChange!(timeStep::CuArray,
                                  phi::CuArray,
                                  epsilonRec::CuArray,
                                  eta::CuArray,
                                  eRec::CuArray,
                                  wRec::CuArray,
                                  exInType::CuArray,
                                  wRecChange::CuArray,
                                  wOut::CuArray,
                                  firingCounter::CuArray,
                                  firingTargetFrequency::CuArray,
                                  arrayProjection4d::CuArray,
                                  nError::CuArray,
                                  modelError::CuArray,
                                  outputError::CuArray,
                                  inputSize::CuArray,
                                  bk::CuArray,
                                  label,
                                  epsilonRecA::CuArray,
                                  beta::CuArray,
                                  )
    eRec .= phi .* (epsilonRec .- (beta .* epsilonRecA)) # use eq. 25

    _accumulateWRecChange!(wRecChange, timeStep, eta, eRec, wRec, wOut, firingCounter,
                           firingTargetFrequency, modelError, inputSize, bk, label)

    # wRecChange .+= 0.01 .* ((firingTargetFrequency - (firingCounter./timeStep)) ./ timeStep) .*
    #     eta .* eRec

    # reset epsilonRec
    epsilonRec .= 0
    epsilonRecA .= 0
    # error("DEBUG -> alifComputeParamsChange! $(Dates.now())")
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, eRec, wOutChange, arrayProjection4d,
                           nError, synapticActivityCounter, outputError)

GPU accumulation of the output-neuron weight change.

Sets `eRec .= phi .* epsilonRec`. It then fills `nError` with `outputError` (reshaped to
`(1, 1, :, size(outputError, 2))`) times `arrayProjection4d`, and adds
`eta .* nError .* eRec` to `wOutChange`. Finally it zeroes `epsilonRec` and returns it.
`synapticActivityCounter` is not used.
"""
function onComputeParamsChange!(phi::CuArray,
                                epsilonRec::CuArray,
                                eta::CuArray,
                                eRec::CuArray,
                                wOutChange::CuArray,
                                arrayProjection4d::CuArray,
                                nError::CuArray,
                                synapticActivityCounter,
                                outputError::CuArray # outputError is output neuron's error
                                )
    eRec .= phi .* epsilonRec
    nError .= reshape(outputError, (1, 1, :, size(outputError, 2))) .* arrayProjection4d
    wOutChange .+= (eta .* nError .* eRec)

    # reset epsilonRec
    epsilonRec .= 0
    # error("DEBUG -> onComputeParamsChange! $(Dates.now())")
end

"""
    lifComputeParamsChange!(phi, epsilonRec, eta, wRec, wRecChange, wOut, modelError)

Generic-array (loop based) LIF weight-change accumulation.

For each neuron `i` and batch `j`, adds
`-eta[1,1,i,j] * phi[1,1,i,j] * epsilonRec[:,:,i,j] * modelError[1,j] * B`
to `wRecChange[:,:,i,j]`. Here `B` is element `i` (linear index) of
`sum(wOut, dims=3)` reshaped to `(d1, :, d4)` and sliced at batch `j`.
Note the negative sign, unlike the CuArray method. `wRec` is not used.
"""
function lifComputeParamsChange!(phi::AbstractArray,
                                 epsilonRec::AbstractArray,
                                 eta::AbstractArray,
                                 wRec::AbstractArray,
                                 wRecChange::AbstractArray,
                                 wOut::AbstractArray,
                                 modelError::AbstractArray)
    d1, _, d3, d4 = size(epsilonRec)
    # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
        etaᵢ = view(eta, :, :, i, j)[1]
        phiᵢ = view(phi, :, :, i, j)[1]
        # learning signal: this neuron receives the summed error (dopamine concept),
        # scaled by its total wOut weight (neuron synaptic subscription)
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]
        view(wRecChange, :, :, i, j) .+=
            (-1 * etaᵢ) .* ((phiᵢ .* view(epsilonRec, :, :, i, j)) .* learningSignal)
    end
end

"""
    alifComputeParamsChange!(phi, epsilonRec, epsilonRecA, eta, wRec, wRecChange, beta,
                             wOut, modelError)

Generic-array (loop based) ALIF weight-change accumulation.

Same as the generic LIF method, but the eligibility trace follows eq. 25 (the same
form the CuArray method uses):
`eRec = phi * (epsilonRec - beta * epsilonRecA)`.
An earlier version added the adaptive term instead of subtracting it. `wRec` is not used.
"""
function alifComputeParamsChange!(phi::AbstractArray,
                                  epsilonRec::AbstractArray,
                                  epsilonRecA::AbstractArray,
                                  eta::AbstractArray,
                                  wRec::AbstractArray,
                                  wRecChange::AbstractArray,
                                  beta::AbstractArray,
                                  wOut::AbstractArray,
                                  modelError::AbstractArray)
    d1, _, d3, d4 = size(epsilonRec)
    # Bₖⱼ in paper, sum() to get each neuron's total wOut weight
    wOutSum = reshape(sum(wOut, dims=3), (d1, :, d4))

    for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
        etaᵢ = view(eta, :, :, i, j)[1]
        phiᵢ = view(phi, :, :, i, j)[1]
        betaᵢ = view(beta, :, :, i, j)[1]
        # eRec = eRec_v - eRec_a (eq. 25)
        eRecᵢ = (phiᵢ .* view(epsilonRec, :, :, i, j)) .-
                ((phiᵢ * betaᵢ) .* view(epsilonRecA, :, :, i, j))
        # learning signal scaled by the neuron's total wOut weight
        learningSignal = view(modelError, :, j)[1] * view(wOutSum, :, :, j)[i]
        # sum(GeneralUtils.isNotEqual.(view(wRec, :, :, i, j), 0) .*
        #     view(wOutSum, :, :, j))
        view(wRecChange, :, :, i, j) .+= (-1 * etaᵢ) .* eRecᵢ .* learningSignal
    end
end

"""
    onComputeParamsChange!(phi, epsilonRec, eta, wOutChange, outputError)

Generic-array (loop based) output-neuron weight-change accumulation.

For each output neuron `i` and batch `j`, adds
`-eta[1,1,i,j] * phi[1,1,i,j] * epsilonRec[:,:,i,j] * outputError[i,j]`
to `wOutChange[:,:,i,j]`. The output neuron's error is its own answer minus the correct
answer.
"""
function onComputeParamsChange!(phi::AbstractArray,
                                epsilonRec::AbstractArray,
                                eta::AbstractArray,
                                wOutChange::AbstractArray,
                                outputError::AbstractArray)
    _, _, d3, d4 = size(epsilonRec)

    for j in 1:d4, i in 1:d3 # compute along neurons axis of every batch
        etaᵢ = view(eta, :, :, i, j)[1]
        phiᵢ = view(phi, :, :, i, j)[1]
        view(wOutChange, :, :, i, j) .+=
            (-1 * etaᵢ) .*
            ((phiᵢ .* view(epsilonRec, :, :, i, j)) .* view(outputError, :, j)[i])
    end
end

"""
    learn!(kfn::kfn_1, progress, device=cpu)

Apply the accumulated weight changes of `kfn`.

LIF and ALIF recurrent weights go through `neuroplasticity` on the CPU. The resulting
`*_wRec`, `*_neuronInactivityCounter`, `*_synapticActivityCounter` and
`*_synapseReconnectDelay` are moved with `device` and stored back on `kfn`. Output
weights are updated by `onLearn!`. If `kfn.learningStage == [3]` it is reset to `[0]`.

When `sum(kfn.timeStep) == 800`, debug statistics are printed, comparing output
neuron 0 with output neuron 5.
"""
function learn!(kfn::kfn_1, progress, device=cpu)
    if sum(kfn.timeStep) == 800
        println("zitCumulative ", sum(kfn.zitCumulative[:,:,784:size(kfn.zitCumulative, 3)], dims=3))
        println("on_synapticActivityCounter 0 ", kfn.on_synapticActivityCounter[:,:,1])
        println("on_synapticActivityCounter 5 ", kfn.on_synapticActivityCounter[:,:,6])
        println("wOut 0 ", sum(kfn.on_wOut[:,:,1,1], dims=3))
        # output neuron for label 5 is index 6 (1-based); this previously printed index 1
        println("wOut 5 ", sum(kfn.on_wOut[:,:,6,1], dims=3))
    end
    #WORKING compare output neuron 0 synapse activity when input are label 0 and 5, (!isequal).(wOut)

    # lif learn
    kfn.lif_wRec, kfn.lif_neuronInactivityCounter, kfn.lif_synapticActivityCounter,
    kfn.lif_synapseReconnectDelay =
        lifLearn(kfn.lif_wRec,
                 kfn.lif_wRecChange,
                 kfn.lif_exInType,
                 kfn.lif_arrayProjection4d,
                 kfn.lif_neuronInactivityCounter,
                 kfn.lif_synapseReconnectDelay,
                 kfn.lif_synapseConnectionNumber,
                 kfn.lif_synapticActivityCounter,
                 kfn.lif_eta,
                 kfn.lif_vt,
                 kfn.zitCumulative,
                 progress,
                 device)

    # alif learn
    kfn.alif_wRec, kfn.alif_neuronInactivityCounter, kfn.alif_synapticActivityCounter,
    kfn.alif_synapseReconnectDelay =
        alifLearn(kfn.alif_wRec,
                  kfn.alif_wRecChange,
                  kfn.alif_exInType,
                  kfn.alif_arrayProjection4d,
                  kfn.alif_neuronInactivityCounter,
                  kfn.alif_synapseReconnectDelay,
                  kfn.alif_synapseConnectionNumber,
                  kfn.alif_synapticActivityCounter,
                  kfn.alif_eta,
                  kfn.alif_vt,
                  kfn.zitCumulative,
                  progress,
                  device)

    # on learn
    onLearn!(kfn.on_wOut,
             kfn.on_wOutChange,
             kfn.on_eta,
             kfn.on_arrayProjection4d,
             progress,)

    # wrap up learning session
    if kfn.learningStage == [3]
        kfn.learningStage = [0]
    end
    # error("DEBUG -> kfn learn! $(Dates.now())")
end

# Shared body of lifLearn/alifLearn. Copies the arrays to the CPU, runs `neuroplasticity`
# there, and writes the CPU `wRecChange` back into the caller's array, so the masking and
# reset done by `neuroplasticity` also reach a GPU `wRecChange`. (On the CPU, `cpu` returns
# the same array, so that already happened.) Returns
# (wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay),
# each moved with `device`.
function _neuroplasticityOnCpu(wRec, wRecChange, exInType, neuronInactivityCounter,
                               synapseReconnectDelay, synapseConnectionNumber,
                               synapticActivityCounter, eta, vt, zitCumulative, progress,
                               device)
    # transfer data to cpu
    wRec_cpu = wRec |> cpu
    wRecChange_cpu = wRecChange |> cpu
    eta_cpu = eta |> cpu
    exInType_cpu = exInType |> cpu
    neuronInactivityCounter_cpu = neuronInactivityCounter |> cpu
    synapseReconnectDelay_cpu = synapseReconnectDelay |> cpu
    synapticActivityCounter_cpu = synapticActivityCounter |> cpu
    zitCumulative_cpu = zitCumulative |> cpu

    # neuroplasticity, work on CPU side
    wRec_cpu, neuronInactivityCounter_cpu, synapticActivityCounter_cpu,
    synapseReconnectDelay_cpu =
        neuroplasticity(synapseConnectionNumber,
                        zitCumulative_cpu,
                        wRec_cpu,
                        exInType_cpu,
                        wRecChange_cpu,
                        vt,
                        eta_cpu,
                        neuronInactivityCounter_cpu,
                        synapseReconnectDelay_cpu,
                        synapticActivityCounter_cpu,
                        progress,)

    # persist neuroplasticity's edits to wRecChange when it lives on another device
    wRecChange_cpu === wRecChange || copyto!(wRecChange, wRecChange_cpu)

    # transfer data back to device
    return (wRec_cpu |> device,
            neuronInactivityCounter_cpu |> device,
            synapticActivityCounter_cpu |> device,
            synapseReconnectDelay_cpu |> device)
end

"""
    lifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
             synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter, eta,
             vt, zitCumulative, progress, device)

Run `neuroplasticity` for the LIF population on the CPU. Returns
`(wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay)` moved
with `device`. `arrayProjection4d` is not used.
"""
function lifLearn(wRec,
                  wRecChange,
                  exInType,
                  arrayProjection4d,
                  neuronInactivityCounter,
                  synapseReconnectDelay,
                  synapseConnectionNumber,
                  synapticActivityCounter,
                  eta,
                  vt,
                  zitCumulative,
                  progress,
                  device)
    # error("DEBUG -> lifLearn! $(Dates.now())")
    return _neuroplasticityOnCpu(wRec, wRecChange, exInType, neuronInactivityCounter,
                                 synapseReconnectDelay, synapseConnectionNumber,
                                 synapticActivityCounter, eta, vt, zitCumulative, progress,
                                 device)
end

"""
    alifLearn(wRec, wRecChange, exInType, arrayProjection4d, neuronInactivityCounter,
              synapseReconnectDelay, synapseConnectionNumber, synapticActivityCounter, eta,
              vt, zitCumulative, progress, device)

Run `neuroplasticity` for the ALIF population on the CPU. The behaviour is the same as
`lifLearn`.
"""
function alifLearn(wRec,
                   wRecChange,
                   exInType,
                   arrayProjection4d,
                   neuronInactivityCounter,
                   synapseReconnectDelay,
                   synapseConnectionNumber,
                   synapticActivityCounter,
                   eta,
                   vt,
                   zitCumulative,
                   progress,
                   device)
    # error("DEBUG -> alifLearn! $(Dates.now())")
    return _neuroplasticityOnCpu(wRec, wRecChange, exInType, neuronInactivityCounter,
                                 synapseReconnectDelay, synapseConnectionNumber,
                                 synapticActivityCounter, eta, vt, zitCumulative, progress,
                                 device)
end

# function onLearn!(wOut,
#                   wOutChange,
#                   arrayProjection4d)
#     # merge learning weight with average learning weight
#     wOut .+= (sum(wOutChange, dims=4) ./ (size(wOut, 4))) .* arrayProjection4d
#     # adaptive wOut to help convergence using c_decay
#     wOut .-= 0.001 .* wOut
# end

"""
    onLearn!(wOut, wOutChange, eta, arrayProjection4d, progress)

Merge output weight changes into `wOut` in place.

If `progress != 0`, adds the batch mean of `wOutChange` (the mean over dim 4, times
`arrayProjection4d`) to `wOut`, then decays it with `wOut .-= 0.1 .* eta .* wOut`.
Otherwise it only zeroes `wOutChange`.

NOTE(review): `wOutChange` is not zeroed after it is merged. Confirm it is reset elsewhere,
otherwise it is applied again on the next call.
"""
function onLearn!(wOut,
                  wOutChange,
                  eta,
                  arrayProjection4d,
                  progress,)
    if progress != 0
        # merge learning weight with average learning weight
        wOut .+= (sum(wOutChange, dims=4) ./ (size(wOut, 4))) .* arrayProjection4d

        # adaptive wOut to help convergence using c_decay
        wOut .-= 0.1 .* eta .* wOut
        # wOut .-= 0.001 .* wOut
    else #TESTING skip
        wOutChange .= 0
    end
end

"""
    neuroplasticity(synapseConnectionNumber, zitCumulative, wRec, exInType, wRecChange, vt,
                    eta, neuronInactivityCounter, synapseReconnectDelay,
                    synapticActivityCounter, progress)

Structural and weight plasticity of one recurrent population, dispatched on `progress`:

- `progress == 2`: only zeroes `wRecChange`.
- any other non-zero `progress`:
  1. zero `wRecChange` where `wRec == 0`;
  2. merge it (`mergeLearnWeight!`);
  3. grow, weaken and prune synapses;
  4. rewire.
- `progress == 0`: prune and rewire only.
- otherwise (e.g. `NaN`): throws an `ErrorException`.

Returns `(wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay)`.
The helper functions are defined outside this module, presumably in `snnUtil`.
"""
function neuroplasticity(synapseConnectionNumber,
                         zitCumulative, # (row, col)
                         wRec, # (row, col, n)
                         exInType,
                         wRecChange,
                         vt,
                         eta,
                         neuronInactivityCounter,
                         synapseReconnectDelay,
                         synapticActivityCounter,
                         progress,) # (row, col, n)

    if progress == 2 # no need to learn for current neural pathway
        # skip neuroplasticity
        #TODO I may need to do something with neuronInactivityCounter and other variables
        wRecChange .= 0
        # error("DEBUG -> neuroplasticity")
    elseif progress != 0 # progress increase
        # ready to reconnect synapse must not have wRecChange
        mask = (!isequal).(wRec, 0)
        wRecChange .*= mask

        # merge learning weight, all resulting negative wRec will get pruned
        mergeLearnWeight!(wRec, exInType, wRecChange, synapticActivityCounter, synapseReconnectDelay)

        # adjust wRec based on repetition (90% +w, 10% -w)
        growRepeatedPath!(wRec, synapticActivityCounter, eta)

        # -w all non-fire connection except mature connection
        weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)

        # prune weak synapse
        pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)

        # rewire synapse connection
        rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
                       synapseReconnectDelay, synapseConnectionNumber, zitCumulative)
        # error("DEBUG -> neuroplasticity 1")
    elseif progress == 0 # no progress, no weight update, only rewire
        # #TESTING -w all non-fire connection except mature connection
        # weakenNotMatureSynapse!(wRec, synapticActivityCounter, eta)

        # prune weak synapse
        pruneSynapse!(wRec, synapticActivityCounter, synapseReconnectDelay)

        # rewire synapse connection
        rewireSynapse!(wRec, neuronInactivityCounter, synapticActivityCounter,
                       synapseReconnectDelay, synapseConnectionNumber, zitCumulative)
        # error("DEBUG -> neuroplasticity")
    else
        error("undefined condition line $(@__LINE__)")
    end

    # error("DEBUG -> neuroplasticity $(Dates.now())")
    return wRec, neuronInactivityCounter, synapticActivityCounter, synapseReconnectDelay
end

# learningLiquidity(x) = -0.0001x + 1
"""
    learningLiquidity(x)

Piecewise-linear map: returns `1.0` for `x < -10000` and `0.0` for `x > 10000`.
Between those bounds it returns `-5e-05x + 0.5`, which runs from 1.0 down to 0.0.
"""
function learningLiquidity(x)
    if x > 10000
        y = 0.0
    elseif x < -10000
        y = 1.0
    else
        y = -5e-05x+0.5 # range -10000 to +10000
    end
    return y
end

end # module