This commit is contained in:
ton
2023-08-10 10:06:21 +07:00
parent 65bb97baf3
commit a80e9f2621
3 changed files with 512 additions and 441 deletions

View File

@@ -9,6 +9,7 @@ using ..type, ..snnUtil
#------------------------------------------------------------------------------------------------100
function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
modelError = reshape(modelError, (1,1,1,:)) # (1,1,1,batch)
lifComputeParamsChange!(kfn.lif_phi,
kfn.lif_epsilonRec,
kfn.lif_eta,
@@ -18,7 +19,10 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
kfn.on_wOut,
kfn.lif_arrayProjection4d,
kfn.lif_error,
modelError)
modelError,
kfn.inputSize,
)
alifComputeParamsChange!(kfn.alif_phi,
kfn.alif_epsilonRec,
@@ -30,7 +34,10 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
kfn.alif_arrayProjection4d,
kfn.alif_error,
modelError,
kfn.alif_beta)
kfn.alif_epsilonRecA,
kfn.alif_beta,
)
onComputeParamsChange!(kfn.on_phi,
kfn.on_epsilonRec,
@@ -38,7 +45,10 @@ function compute_paramsChange!(kfn::kfn_1, modelError, outputError)
kfn.on_eRec,
kfn.on_wOut,
kfn.on_wOutChange,
outputError)
kfn.on_arrayProjection4d,
kfn.on_error,
outputError,
)
# error("DEBUG -> kfn compute_paramsChange! $(Dates.now())")
end
@@ -51,18 +61,28 @@ function lifComputeParamsChange!( phi::CuArray,
wOut::CuArray,
arrayProjection4d::CuArray,
nError::CuArray,
modelError::CuArray)
# Bₖⱼ in paper, sum() to get each neuron's total wOut weight
wOutSum = sum(wOut, dims=3) .* arrayProjection4d
modelError::CuArray,
inputSize::CuArray,
)
# Bₖⱼ in paper, sum() to get each neuron's total wOut weight,
# use absolute because only magnitude is needed
wOutSum_all = reshape( abs.(sum(wOut, dims=3)), (1,1,:, size(wOut, 4)) ) # (1,1,allNeuron,batch)
# get only each lif neuron's wOut, leaving out other neuron's wOut
startIndex = prod(inputSize) +1
stopIndex = startIndex + size(wRec, 3) -1
wOutSum = @view(wOutSum_all[1,1, startIndex:stopIndex, :])
wOutSum = reshape(wOutSum, (1, 1, size(wOutSum, 1), size(wOutSum, 2))) # (1,1,n,batch)
# nError, a.k.a. the learning signal, uses the dopamine concept:
# this neuron receives the summed error signal (modelError)
nError .= (modelError .* arrayProjection4d) .* wOutSum
nError .= (modelError .* wOutSum) .* arrayProjection4d
eRec .= phi .* epsilonRec
# GeneralUtils.isNotEqual(wRec, 0) is a subscription filter used to filter out non-subscribed wRecChange
wRecChange .+= ((-1 .* eta) .* nError .* eRec .* sign.(wRec)) .* GeneralUtils.isNotEqual.(wRec, 0)
# error("DEBUG -> lifComputeParamsChange! $(Dates.now())")
wRecChange .+= ((-1 .* eta) .* nError .* eRec)
# reset epsilonRec
epsilonRec .= 0
end
function alifComputeParamsChange!( phi::CuArray,
@@ -75,18 +95,29 @@ function alifComputeParamsChange!( phi::CuArray,
arrayProjection4d::CuArray,
nError::CuArray,
modelError::CuArray,
beta::CuArray)
# Bₖⱼ in paper, sum() to get each neuron's total wOut weight
wOutSum = sum(wOut, dims=3) .* arrayProjection4d
epsilonRecA::CuArray,
beta::CuArray
)
# Bₖⱼ in paper, sum() to get each neuron's total wOut weight,
# use absolute because only magnitude is needed
wOutSum_all = reshape( abs.(sum(wOut, dims=3)), (1,1,:, size(wOut, 4)) ) # (1,1,allNeuron,batch)
# get only each lif neuron's wOut, leaving out other neuron's wOut
wOutSum = @view(wOutSum_all[1,1, end-size(wRec, 3)+1:end, :])
wOutSum = reshape(wOutSum, (1, 1, size(wOutSum, 1), size(wOutSum, 2))) # (1,1,n,batch)
# nError, a.k.a. the learning signal, uses the dopamine concept:
# this neuron receives the summed error signal (modelError)
nError .= (modelError .* arrayProjection4d) .* wOutSum
eRec .= (phi .* epsilonRec) .+ (phi .* epsilonRec .* beta)
nError .= (modelError .* wOutSum) .* arrayProjection4d
eRec .= phi .* (epsilonRec .- (beta .* epsilonRecA)) # use eq. 25
wRecChange .+= ((-1 .* eta) .* nError .* eRec)
# reset epsilonRec
epsilonRec .= 0
epsilonRecA .= 0
# GeneralUtils.isNotEqual(wRec, 0) is a subscription filter used to filter out non-subscribed wRecChange
wRecChange .+= ((-1 .* eta) .* nError .* eRec .* sign.(wRec)) .* GeneralUtils.isNotEqual.(wRec, 0)
# error("DEBUG -> alifComputeParamsChange! $(Dates.now())")
end
@@ -96,15 +127,17 @@ function onComputeParamsChange!(phi::CuArray,
eRec::CuArray,
wOut::CuArray,
wOutChange::CuArray,
arrayProjection4d::CuArray,
nError::CuArray,
outputError::CuArray # outputError is output neuron's error
)
# nError, a.k.a. the learning signal, uses the dopamine concept:
# this neuron receives the summed error signal (modelError)
eRec .= (phi .* epsilonRec) .* reshape(outputError, (1, 1, :, size(epsilonRec, 4)))
# GeneralUtils.isNotEqual(wRec, 0) is a subscription filter used to filter out non-subscribed wRecChange
wOutChange .+= ((-1 .* eta) .* eRec .* sign.(wOut)) .* GeneralUtils.isNotEqual.(wOut, 0)
eRec .= phi .* epsilonRec
nError .= reshape(outputError, (1, 1, :, size(outputError, 2))) .* arrayProjection4d
wOutChange .+= ((-1 .* eta) .* nError .* eRec)
# reset epsilonRec
epsilonRec .= 0
# error("DEBUG -> onComputeParamsChange! $(Dates.now())")
end
@@ -224,20 +257,20 @@ end
function lifLearn!(wRec,
wRecChange,
arrayProjection4d)
# merge learning weight with average learning weight
wRec .+= (sum(wRecChange) ./ (size(wRec, 4))) .* arrayProjection4d
wRec .+= (sum(wRecChange, dims=4) ./ (size(wRec, 4))) .* arrayProjection4d
#TODO synaptic strength
#TODO neuroplasticity
# error("DEBUG -> lifLearn! $(Dates.now())")
end
function alifLearn!(wRec,
wRecChange,
arrayProjection4d)
# merge learning weight
# merge learning weight with average learning weight
wRec .+= (sum(wRecChange) ./ (size(wRec, 4))) .* arrayProjection4d
#TODO synaptic strength
@@ -249,7 +282,7 @@ end
function onLearn!(wOut,
wOutChange,
arrayProjection4d)
# merge learning weight
# merge learning weight with average learning weight
wOut .+= (sum(wOutChange) ./ (size(wOut, 4))) .* arrayProjection4d
# adaptive wOut to help convergence using c_decay