bug fix
This commit is contained in:
@@ -129,7 +129,6 @@ function (n::lifNeuron)(kfn::knowledgeFn)
|
|||||||
n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
else
|
else
|
||||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
|
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||||
n.v_t1 = no_negative!(n.v_t1)
|
n.v_t1 = no_negative!(n.v_t1)
|
||||||
@@ -171,7 +170,6 @@ function (n::alifNeuron)(kfn::knowledgeFn)
|
|||||||
n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
n.phi = 0
|
n.phi = 0
|
||||||
else
|
else
|
||||||
n.z_t = isnothing(n.z_t) ? false : n.z_t
|
|
||||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
||||||
n.av_th = n.v_th + (n.beta * n.a)
|
n.av_th = n.v_th + (n.beta * n.a)
|
||||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
@@ -205,7 +203,7 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
|||||||
n.timeStep = kfn.timeStep
|
n.timeStep = kfn.timeStep
|
||||||
|
|
||||||
# pulling other neuron's firing status at time t
|
# pulling other neuron's firing status at time t
|
||||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
n.z_i_t = getindex(kfn.firedNeurons_t1, n.subscriptionList)
|
||||||
|
|
||||||
if n.refractoryCounter != 0
|
if n.refractoryCounter != 0
|
||||||
n.refractoryCounter -= 1
|
n.refractoryCounter -= 1
|
||||||
@@ -217,13 +215,13 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
|||||||
|
|
||||||
# decay of v_t1
|
# decay of v_t1
|
||||||
n.v_t1 = n.alpha * n.v_t
|
n.v_t1 = n.alpha * n.v_t
|
||||||
|
n.vError = n.v_t1
|
||||||
else
|
else
|
||||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
|
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||||
n.v_t1 = no_negative!(n.v_t1)
|
n.v_t1 = no_negative!(n.v_t1)
|
||||||
|
n.vError = n.v_t1
|
||||||
if n.v_t1 > n.v_th
|
if n.v_t1 > n.v_th
|
||||||
n.z_t1 = true
|
n.z_t1 = true
|
||||||
n.refractoryCounter = n.refractoryDuration
|
n.refractoryCounter = n.refractoryDuration
|
||||||
|
|||||||
26
src/learn.jl
26
src/learn.jl
@@ -27,7 +27,7 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
|||||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||||
for (i, out) in enumerate(outs)
|
for (i, out) in enumerate(outs)
|
||||||
if out != correctAnswer[i] # need to adjust weight
|
if out != correctAnswer[i] # need to adjust weight
|
||||||
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].v_t1) *
|
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||||
100 / kfn.outputNeuronsArray[i].v_th
|
100 / kfn.outputNeuronsArray[i].v_th
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
# Threads.@threads for n in kfn.neuronsArray
|
||||||
@@ -35,7 +35,7 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
|||||||
learn!(n, kfnError)
|
learn!(n, kfnError)
|
||||||
end
|
end
|
||||||
|
|
||||||
learn!(kfn.outputNeuronsArray[i], kfn)
|
learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -43,18 +43,20 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
|||||||
if kfn.learningStage == "end_learning"
|
if kfn.learningStage == "end_learning"
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
# Threads.@threads for n in kfn.neuronsArray
|
||||||
for n in kfn.neuronsArray
|
for n in kfn.neuronsArray
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
if typeof(n) <: computeNeuron
|
||||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
# normalize wRec peak to prevent input signal overwhelming neuron
|
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||||
normalizePeak!(n.wRec, 2)
|
# normalize wRec peak to prevent input signal overwhelming neuron
|
||||||
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
normalizePeak!(n.wRec, 2)
|
||||||
|
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||||
|
|
||||||
|
|
||||||
synapticConnStrength!(n)
|
synapticConnStrength!(n)
|
||||||
#TODO neuroplasticity
|
#TODO neuroplasticity
|
||||||
println("")
|
println("")
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for n in kfn.outputNeuronsArray # merge wRecChange into wRec
|
for n in kfn.outputNeuronsArray # merge wRecChange into wRec
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ function resetLearningParams!(n::alifNeuron)
|
|||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
reset_refractoryCounter!(n)
|
# reset_refractoryCounter!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
# function reset_learning_no_wchange!(n::passthroughNeuron)
|
# function reset_learning_no_wchange!(n::passthroughNeuron)
|
||||||
@@ -136,12 +136,12 @@ end
|
|||||||
function resetLearningParams!(n::linearNeuron)
|
function resetLearningParams!(n::linearNeuron)
|
||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
reset_v_t!(n)
|
# reset_v_t!(n)
|
||||||
reset_firing_counter!(n)
|
reset_firing_counter!(n)
|
||||||
|
|
||||||
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
# reset refractory state at the start/end of episode. Otherwise once neuron goes into
|
||||||
# refractory state, it will stay in refractory state forever
|
# refractory state, it will stay in refractory state forever
|
||||||
reset_refractoryCounter!(n)
|
# reset_refractoryCounter!(n)
|
||||||
end
|
end
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
@@ -279,6 +279,7 @@ end
|
|||||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||||
|
#WORKING n.epsilonRec is all 0.0 why? may b it was reset?
|
||||||
updown = n.epsilonRec[i] == 0.0 ? "down" : "up"
|
updown = n.epsilonRec[i] == 0.0 ? "down" : "up"
|
||||||
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||||
@@ -321,7 +322,7 @@ end
|
|||||||
within its radius. radius must be odd number
|
within its radius. radius must be odd number
|
||||||
"""
|
"""
|
||||||
function normalizePeak!(v::Vector, radius::Integer=2)
|
function normalizePeak!(v::Vector, radius::Integer=2)
|
||||||
peak = findall(isequal.(v, maximum(abs.(v))))[1]
|
peak = findall(isequal.(abs.(v), maximum(abs.(v))))[1]
|
||||||
upindex = peak - radius
|
upindex = peak - radius
|
||||||
upindex = upindex < 1 ? 1 : upindex
|
upindex = upindex < 1 ? 1 : upindex
|
||||||
downindex = peak + radius
|
downindex = peak + radius
|
||||||
|
|||||||
@@ -512,6 +512,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
|||||||
v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||||
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
||||||
vRest::Float64 = 0.0 # resting potential after neuron fired
|
vRest::Float64 = 0.0 # resting potential after neuron fired
|
||||||
|
vError::Union{Float64,Nothing} = nothing # used to compute model error
|
||||||
z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep
|
z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep
|
||||||
# zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all
|
# zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all
|
||||||
# neurons forward function at each timestep-by-timestep is to do every neuron
|
# neurons forward function at each timestep-by-timestep is to do every neuron
|
||||||
@@ -539,7 +540,6 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
|||||||
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated wRec change
|
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated wRec change
|
||||||
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
||||||
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
||||||
error::Union{Float64,Nothing} = nothing # local neuron error
|
|
||||||
|
|
||||||
firingCounter::Integer = 0 # store how many times neuron fires
|
firingCounter::Integer = 0 # store how many times neuron fires
|
||||||
end
|
end
|
||||||
@@ -627,7 +627,7 @@ function init_neuron!(id::Int64, n::lifNeuron, n_params::Dict, kfnParams::Dict)
|
|||||||
n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList))
|
n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList))
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.wRec = rand(length(n.subscriptionList))
|
n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList))
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
n.alpha = calculate_α(n)
|
n.alpha = calculate_α(n)
|
||||||
end
|
end
|
||||||
@@ -646,7 +646,7 @@ function init_neuron!(id::Int64, n::alifNeuron, n_params::Dict,
|
|||||||
n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList))
|
n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList))
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.wRec = rand(length(n.subscriptionList))
|
n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList))
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
|
|
||||||
# the more time has passed from the last time neuron was activated, the more
|
# the more time has passed from the last time neuron was activated, the more
|
||||||
@@ -668,7 +668,7 @@ function init_neuron!(id::Int64, n::linearNeuron, n_params::Dict, kfnParams::Dic
|
|||||||
n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList))
|
n.synapticStrength = rand(-5:0.1:-3, length(n.subscriptionList))
|
||||||
|
|
||||||
n.epsilonRec = zeros(length(n.subscriptionList))
|
n.epsilonRec = zeros(length(n.subscriptionList))
|
||||||
n.wRec = rand(length(n.subscriptionList))
|
n.wRec = rand(-0.2:0.01:0.2, length(n.subscriptionList))
|
||||||
n.wRecChange = zeros(length(n.subscriptionList))
|
n.wRecChange = zeros(length(n.subscriptionList))
|
||||||
n.alpha = calculate_k(n)
|
n.alpha = calculate_k(n)
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user