Files
IronpenGPU/src/learn.jl
2023-08-10 13:32:32 +07:00

347 lines
11 KiB
Julia

module learn
export learn!, compute_paramsChange!
using Statistics, Random, LinearAlgebra, JSON3, Flux, CUDA, Dates
using GeneralUtils
using ..type, ..snnUtil
#------------------------------------------------------------------------------------------------100
"""
    compute_paramsChange!(kfn::kfn_1, modelError, outputError)

Accumulate pending weight changes for the three neuron groups stored in `kfn`.

`modelError` is reshaped to `(1, 1, 1, batch)` and handed to the `lif` and `alif`
helpers; `outputError` is handed unchanged to the output-neuron (`on`) helper.
All work happens in place on the fields of `kfn`. The value returned is whatever
`onComputeParamsChange!` returns.
"""
function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
    # Make the error broadcastable against the 4-d neuron arrays: (1,1,1,batch).
    modelError4d = reshape(modelError, (1, 1, 1, :))

    # lif group: its wOut slice is located through kfn.inputSize.
    lifComputeParamsChange!(
        kfn.lif_phi, kfn.lif_epsilonRec, kfn.lif_eta, kfn.lif_eRec,
        kfn.lif_wRec, kfn.lif_wRecChange,
        kfn.on_wOut,
        kfn.lif_arrayProjection4d, kfn.lif_error,
        modelError4d,
        kfn.inputSize,
    )

    # alif group: additionally needs the adaptation trace and beta.
    alifComputeParamsChange!(
        kfn.alif_phi, kfn.alif_epsilonRec, kfn.alif_eta, kfn.alif_eRec,
        kfn.alif_wRec, kfn.alif_wRecChange,
        kfn.on_wOut,
        kfn.alif_arrayProjection4d, kfn.alif_error,
        modelError4d,
        kfn.alif_epsilonRecA, kfn.alif_beta,
    )

    # output neurons: driven by their own error instead of the summed model error.
    return onComputeParamsChange!(
        kfn.on_phi, kfn.on_epsilonRec, kfn.on_eta, kfn.on_eRec,
        kfn.on_wOut, kfn.on_wOutChange,
        kfn.on_arrayProjection4d, kfn.on_error,
        outputError,
    )
end
"""
    lifComputeParamsChange!(phi, epsilonRec, eta, eRec, wRec, wRecChange, wOut,
                            arrayProjection4d, nError, modelError, inputSize)

`CuArray` method. Accumulates the lif weight change into `wRecChange`.

Steps, all in place:
- `nError` is overwritten with `modelError .* |sum(wOut, dims=3)| .* arrayProjection4d`,
  using only the slice of the summed `wOut` that starts right after the first
  `prod(inputSize)` entries and spans `size(wRec, 3)` entries.
- `eRec` is overwritten with `phi .* epsilonRec`.
- `wRecChange` is incremented by `-eta .* nError .* eRec`.
- `epsilonRec` is zeroed.

`wRec` is read only for its size along dimension 3. Returns `epsilonRec`.
"""
function lifComputeParamsChange!(
    phi::CuArray,
    epsilonRec::CuArray,
    eta::CuArray,
    eRec::CuArray,
    wRec::CuArray,
    wRecChange::CuArray,
    wOut::CuArray,
    arrayProjection4d::CuArray,
    nError::CuArray,
    modelError::CuArray,
    inputSize::CuArray,
)
    # B_kj in the paper: magnitude of every neuron's total wOut weight,
    # laid out as (1, 1, allNeuron, batch).
    readoutAll = reshape(abs.(sum(wOut, dims=3)), (1, 1, :, size(wOut, 4)))

    # Keep only the entries belonging to this group; they come directly after
    # the first prod(inputSize) entries.
    firstSlot = prod(inputSize) + 1
    lastSlot = firstSlot + size(wRec, 3) - 1
    readoutSlice = view(readoutAll, 1, 1, firstSlot:lastSlot, :)
    readoutSlice = reshape(readoutSlice, (1, 1, size(readoutSlice)...)) # (1,1,n,batch)

    # Learning signal (dopamine concept): summed model error scaled per neuron.
    @. nError = (modelError * readoutSlice) * arrayProjection4d
    @. eRec = phi * epsilonRec
    @. wRecChange += (-eta * nError) * eRec

    # The trace has been consumed; clear it.
    return fill!(epsilonRec, 0)
end
"""
    alifComputeParamsChange!(phi, epsilonRec, eta, eRec, wRec, wRecChange, wOut,
                             arrayProjection4d, nError, modelError, epsilonRecA, beta)

`CuArray` method. Accumulates the alif weight change into `wRecChange`.

Steps, all in place:
- `nError` is overwritten with `modelError .* |sum(wOut, dims=3)| .* arrayProjection4d`,
  using the LAST `size(wRec, 3)` entries of the summed `wOut`.
- `eRec` is overwritten with `phi .* (epsilonRec .- beta .* epsilonRecA)` (eq. 25).
- `wRecChange` is incremented by `-eta .* nError .* eRec`.
- `epsilonRec` and `epsilonRecA` are zeroed.

`wRec` is read only for its size along dimension 3. Returns `epsilonRecA`.
"""
function alifComputeParamsChange!(
    phi::CuArray,
    epsilonRec::CuArray,
    eta::CuArray,
    eRec::CuArray,
    wRec::CuArray,
    wRecChange::CuArray,
    wOut::CuArray,
    arrayProjection4d::CuArray,
    nError::CuArray,
    modelError::CuArray,
    epsilonRecA::CuArray,
    beta::CuArray,
)
    # B_kj in the paper: magnitude of every neuron's total wOut weight,
    # laid out as (1, 1, allNeuron, batch).
    readoutAll = reshape(abs.(sum(wOut, dims=3)), (1, 1, :, size(wOut, 4)))

    # This group's entries sit at the tail of the neuron axis.
    tailEnd = lastindex(readoutAll, 3)
    tailStart = tailEnd - size(wRec, 3) + 1
    readoutSlice = view(readoutAll, 1, 1, tailStart:tailEnd, :)
    readoutSlice = reshape(readoutSlice, (1, 1, size(readoutSlice)...)) # (1,1,n,batch)

    # Learning signal (dopamine concept): summed model error scaled per neuron.
    @. nError = (modelError * readoutSlice) * arrayProjection4d
    # Eligibility trace with the adaptation term subtracted (eq. 25).
    @. eRec = phi * (epsilonRec - (beta * epsilonRecA))
    @. wRecChange += (-eta * nError) * eRec

    # Both traces have been consumed; clear them.
    fill!(epsilonRec, 0)
    return fill!(epsilonRecA, 0)
end
"""
    onComputeParamsChange!(phi, epsilonRec, eta, eRec, wOut, wOutChange,
                           arrayProjection4d, nError, outputError)

`CuArray` method for the output neurons. Accumulates into `wOutChange`.

Steps, all in place:
- `eRec` is overwritten with `phi .* epsilonRec`.
- `nError` is overwritten with `outputError` (reshaped to
  `(1, 1, :, size(outputError, 2))`) times `arrayProjection4d`.
- `wOutChange` is incremented by `-eta .* nError .* eRec`.
- `epsilonRec` is zeroed.

`wOut` is accepted but not read. Returns `epsilonRec`.
"""
function onComputeParamsChange!(
    phi::CuArray,
    epsilonRec::CuArray,
    eta::CuArray,
    eRec::CuArray,
    wOut::CuArray,
    wOutChange::CuArray,
    arrayProjection4d::CuArray,
    nError::CuArray,
    outputError::CuArray, # each output neuron's own error
)
    @. eRec = phi * epsilonRec

    # Put the per-neuron error on the third axis so it lines up with the 4-d arrays.
    outputError4d = reshape(outputError, (1, 1, :, size(outputError, 2)))
    @. nError = outputError4d * arrayProjection4d

    @. wOutChange += (-eta * nError) * eRec

    # The trace has been consumed; clear it.
    return fill!(epsilonRec, 0)
end
"""
    lifComputeParamsChange!(phi, epsilonRec, eta, wRec, wRecChange, wOut, modelError)

Generic-array method. For every neuron `n` (dimension 3 of `epsilonRec`) and batch
entry `b` (dimension 4), add to `wRecChange[:, :, n, b]`:

    -eta[1,1,n,b] * (phi[1,1,n,b] * epsilonRec[:, :, n, b]) * (modelError[1,b] * B[n])

where `B` is `sum(wOut, dims=3)` reshaped to `(size(epsilonRec, 1), :, batch)` and
linearly indexed inside batch slice `b`.

Only `wRecChange` is mutated. `wRec` is accepted but not read. Returns `nothing`.
"""
function lifComputeParamsChange!(
    phi::AbstractArray,
    epsilonRec::AbstractArray,
    eta::AbstractArray,
    wRec::AbstractArray,
    wRecChange::AbstractArray,
    wOut::AbstractArray,
    modelError::AbstractArray,
)
    nRow, _, nNeuron, nBatch = size(epsilonRec)
    # B_kj in the paper: each neuron's total wOut weight.
    readout = reshape(sum(wOut, dims=3), (nRow, :, nBatch))

    for b in 1:nBatch
        for n in 1:nNeuron
            negEta = -1 * eta[1, 1, n, b]
            psi = phi[1, 1, n, b]
            # Learning signal: the summed error of this batch entry (dopamine
            # concept) weighted by this neuron's total wOut weight.
            learningSignal = modelError[1, b] * view(readout, :, :, b)[n]
            trace = view(epsilonRec, :, :, n, b)
            target = view(wRecChange, :, :, n, b)
            target .+= negEta .* ((psi .* trace) .* learningSignal)
        end
    end
    return nothing
end
"""
    alifComputeParamsChange!(phi, epsilonRec, epsilonRecA, eta, wRec, wRecChange,
                             beta, wOut, modelError)

Generic-array method. For every neuron `n` (dimension 3 of `epsilonRec`) and batch
entry `b` (dimension 4), add to `wRecChange[:, :, n, b]`:

    -eta[1,1,n,b] * eRec * (modelError[1,b] * B[n])

with the eligibility trace

    eRec = phi[1,1,n,b] * (epsilonRec[:, :, n, b] - beta[1,1,n,b] * epsilonRecA[:, :, n, b])

and `B` equal to `sum(wOut, dims=3)` reshaped to `(size(epsilonRec, 1), :, batch)`,
linearly indexed inside batch slice `b`.

The adaptation term is SUBTRACTED, matching the `CuArray` method of this function
(eq. 25); the previous version of this method added it.

Only `wRecChange` is mutated. `wRec` is accepted but not read. Returns `nothing`.
"""
function alifComputeParamsChange!(
    phi::AbstractArray,
    epsilonRec::AbstractArray,
    epsilonRecA::AbstractArray,
    eta::AbstractArray,
    wRec::AbstractArray,
    wRecChange::AbstractArray,
    beta::AbstractArray,
    wOut::AbstractArray,
    modelError::AbstractArray,
)
    nRow, _, nNeuron, nBatch = size(epsilonRec)
    # B_kj in the paper: each neuron's total wOut weight.
    readout = reshape(sum(wOut, dims=3), (nRow, :, nBatch))

    for b in 1:nBatch, n in 1:nNeuron # every neuron of every batch entry
        negEta = -1 * view(eta, :, :, n, b)[1]
        psi = view(phi, :, :, n, b)[1]
        betaN = view(beta, :, :, n, b)[1]

        # Learning signal: summed error of this batch entry weighted by this
        # neuron's total wOut weight.
        learningSignal = view(modelError, :, b)[1] * view(readout, :, :, b)[n]

        traceV = view(epsilonRec, :, :, n, b)  # voltage part of the trace
        traceA = view(epsilonRecA, :, :, n, b) # adaptation part of the trace

        # eRec = psi * (traceV - beta * traceA); the adaptation part enters with
        # a minus sign (eq. 25).
        view(wRecChange, :, :, n, b) .+=
            negEta .* ((psi .* traceV) .- ((psi * betaN) .* traceA)) .* learningSignal
    end
    return nothing
end
"""
    onComputeParamsChange!(phi, epsilonRec, eta, wOutChange, outputError)

Generic-array method for the output neurons. For every neuron `n` (dimension 3 of
`epsilonRec`) and batch entry `b` (dimension 4), add to `wOutChange[:, :, n, b]`:

    -eta[1,1,n,b] * (phi[1,1,n,b] * epsilonRec[:, :, n, b]) * outputError[n, b]

Only `wOutChange` is mutated. Returns `nothing`.
"""
function onComputeParamsChange!(
    phi::AbstractArray,
    epsilonRec::AbstractArray,
    eta::AbstractArray,
    wOutChange::AbstractArray,
    outputError::AbstractArray,
)
    _, _, nNeuron, nBatch = size(epsilonRec)

    for b in 1:nBatch
        for n in 1:nNeuron
            negEta = -1 * eta[1, 1, n, b]
            psi = phi[1, 1, n, b]
            trace = view(epsilonRec, :, :, n, b)
            target = view(wOutChange, :, :, n, b)
            # Learning signal: an output neuron receives the error of its own
            # answer (its answer minus the correct answer).
            target .+= negEta .* ((psi .* trace) .* outputError[n, b])
        end
    end
    return nothing
end
"""
    learn!(kfn::kfn_1)

Apply the accumulated weight changes of `kfn` to its weights.

Calls `lifLearn!`, `alifLearn!` and `onLearn!` on the matching `lif_*`, `alif_*`
and `on_*` fields, then wraps up the learning session: when `kfn.learningStage`
equals `[3]` it is replaced by `[0]`.

Returns `[0]` when the stage was reset, otherwise `nothing`.
"""
function learn!(kfn::kfn_1)
    lifLearn!(kfn.lif_wRec, kfn.lif_wRecChange, kfn.lif_arrayProjection4d)
    alifLearn!(kfn.alif_wRec, kfn.alif_wRecChange, kfn.alif_arrayProjection4d)
    onLearn!(kfn.on_wOut, kfn.on_wOutChange, kfn.on_arrayProjection4d)

    # Wrap up the learning session.
    return kfn.learningStage == [3] ? (kfn.learningStage = [0]) : nothing
end
"""
    lifLearn!(wRec, wRecChange, arrayProjection4d)

Merge the batch-averaged weight change into `wRec`, in place.

`wRecChange` is summed over dimension 4, divided by `size(wRec, 4)`, multiplied
element-wise by `arrayProjection4d`, and added to `wRec`. Returns `wRec`.
"""
function lifLearn!(wRec, wRecChange, arrayProjection4d)
    batchTotal = sum(wRecChange, dims=4)
    nBatch = size(wRec, 4)
    # Average over the batch, then project back onto every batch slice.
    @. wRec += (batchTotal / nBatch) * arrayProjection4d
    #TODO synaptic strength
    #TODO neuroplasticity
    return wRec
end
"""
    alifLearn!(wRec, wRecChange, arrayProjection4d)

Merge the batch-averaged weight change into `wRec`, in place.

The sum of `wRecChange` over dimension 4 is divided by `size(wRec, 4)`, multiplied
element-wise by `arrayProjection4d`, and added to `wRec`. Returns `wRec`.
"""
function alifLearn!(wRec, wRecChange, arrayProjection4d)
    summedChange = sum(wRecChange, dims=4)
    batchCount = size(wRec, 4)
    # Average over the batch, then project back onto every batch slice.
    wRec .= wRec .+ ((summedChange ./ batchCount) .* arrayProjection4d)
    #TODO synaptic strength
    #TODO neuroplasticity
    return wRec
end
"""
    onLearn!(wOut, wOutChange, arrayProjection4d; c_decay=0.001)

Merge the batch-averaged weight change into `wOut` and then decay `wOut`, in place.

# Arguments
- `wOut`: weights to update; its size along dimension 4 is the averaging divisor.
- `wOutChange`: accumulated changes; summed over dimension 4.
- `arrayProjection4d`: multiplied element-wise into the averaged change.

# Keywords
- `c_decay`: fraction of `wOut` removed after the merge (`wOut .-= c_decay .* wOut`).
  Defaults to `0.001`, the value that used to be hard-coded; pass `0` to disable
  the decay.

Returns `wOut`.
"""
function onLearn!(wOut, wOutChange, arrayProjection4d; c_decay=0.001)
    batchTotal = sum(wOutChange, dims=4)
    nBatch = size(wOut, 4)
    # Average over the batch, then project back onto every batch slice.
    @. wOut += (batchTotal / nBatch) * arrayProjection4d
    # Adaptive wOut to help convergence: shrink every weight by c_decay.
    @. wOut -= c_decay * wOut
    #TODO synaptic strength
    #TODO neuroplasticity
    return wOut
end
#TODO voltage regulator
#TODO frequency regulator
end # module