diff --git a/previousVersion/0.0.3/main.jl b/previousVersion/0.0.3/main.jl index bcf50aa..031797b 100644 --- a/previousVersion/0.0.3/main.jl +++ b/previousVersion/0.0.3/main.jl @@ -593,7 +593,7 @@ function validate(model, dataset, labelDict) thinkingPeriod = 16 # 1000-784 = 216 predict = [0] |> device - n = length(trainData) + n = length(dataset) println("n $n") p = Progress(n, dt=1.0) # minimum update interval: 1 second for (imgBatch, labels) in dataset diff --git a/src/forward.jl b/src/forward.jl index 09873e2..12d6719 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -235,8 +235,7 @@ function lifForward( zit, phi[i1,i2,i3,i4] = 0 # compute epsilonRec - epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + - (zit[i1,i2,i3,i4] * subscription[i1,i2,i3,i4]) + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) else # refractory period is inactive recSignal[i1,i2,i3,i4] = wRec[i1,i2,i3,i4] * zit[i1,i2,i3,i4] @@ -386,13 +385,10 @@ function alifForward( zit, a[i1,i2,i3,i4] = rho[i1,i2,i3,i4] * a[i1,i2,i3,i4] # compute epsilonRec - epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + - (zit[i1,i2,i3,i4] * subscription[i1,i2,i3,i4]) + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) # compute epsilonRecA - epsilonRecA[i1,i2,i3,i4] = (phi[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + - ((rho[i1,i2,i3,i4] - (phi[i1,i2,i3,i4] * beta[i1,i2,i3,i4])) * - epsilonRecA[i1,i2,i3,i4]) + epsilonRecA[i1,i2,i3,i4] = (phi[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) # compute avth avth[i1,i2,i3,i4] = vth[i1,i2,i3,i4] + (beta[i1,i2,i3,i4] * a[i1,i2,i3,i4]) @@ -534,8 +530,7 @@ function onForward( zit, phi[i1,i2,i3,i4] = 0 # compute epsilonRec - epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) + - (zit[i1,i2,i3,i4] * subscription[i1,i2,i3,i4]) + epsilonRec[i1,i2,i3,i4] = (alpha[i1,i2,i3,i4] * epsilonRec[i1,i2,i3,i4]) else # refractory period is inactive recSignal[i1,i2,i3,i4] = zit[i1,i2,i3,i4] * wOut[i1,i2,i3,i4] diff --git a/src/learn.jl b/src/learn.jl index 168e7da..d9197da 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -9,7 +9,10 @@ using ..type, ..snnUtil #------------------------------------------------------------------------------------------------100 function compute_paramsChange!(kfn::kfn_1, modelError, outputError) - modelError = reshape(modelError, (1,1,1,:)) # (1,1,1,batch) + # modelError = reshape(modelError, (1,1,1,:)) # (1,1,1,batch) + modelError = reshape(modelError, (1,1,:, size(modelError, 2))) + modelError = sum(modelError, dims=3) + lifComputeParamsChange!(kfn.timeStep, kfn.lif_phi, kfn.lif_epsilonRec, @@ -92,6 +95,23 @@ function lifComputeParamsChange!( timeStep::CuArray, #TODO frequency regulator wRecChange .+= 0.0001 .* ((firingTargetFrequency - (firingCounter./timeStep)) .* timeStep) .* eta .* eRec + + # if sum(timeStep) == 785 + # epsilonRec_cpu = epsilonRec |> cpu + # println("modelError ", modelError) + # println("") + # wchange = (-eta .* nError .* eRec) |> cpu + # println("wchange 5 1 ", wchange[:,:,5,1]) + # println("") + # println("wchange 5 2 ", wchange[:,:,5,2]) + # println("") + # println("epsilonRec 5 1 ", epsilonRec_cpu[:,:,5,1]) + # println("") + # println("epsilonRec 5 2 ", epsilonRec_cpu[:,:,5,2]) + # println("") + # error("DEBUG lifComputeParamsChange!") + # end + # reset epsilonRec epsilonRec .= 0 end