refractoring
This commit is contained in:
@@ -35,7 +35,7 @@ using .interface
|
||||
"""
|
||||
Todo:
|
||||
[3] verify that model can complete learning cycle with no error
|
||||
[2] neuroplasticity() i.e. change connection
|
||||
[*2] neuroplasticity() i.e. change connection
|
||||
[] using RL to control learning signal
|
||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||
[] training should include adjusting α, neuron membrane potential decay factor
|
||||
|
||||
@@ -31,6 +31,29 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
|
||||
kfn.learningStage = m.learningStage
|
||||
|
||||
if kfn.learningStage == "start_learning"
|
||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||
# don't gets wiped and can be logged for visualization later
|
||||
for n in kfn.neuronsArray
|
||||
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
||||
# use this info to calculate how much synaptic weight should be adjust
|
||||
resetLearningParams!(n)
|
||||
end
|
||||
|
||||
for n in kfn.output_neuron
|
||||
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
||||
# use this info to calculate how much synaptic weight should be adjust
|
||||
resetLearningParams!(n)
|
||||
end
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "learning"
|
||||
end
|
||||
|
||||
# generate noise
|
||||
noise = [GeneralUtils.randomChoiceWithProb([true, false],[0.5,0.5])
|
||||
for i in 1:length(input_data)]
|
||||
@@ -41,6 +64,9 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
for n in kfn.neuronsArray
|
||||
timestep_forward!(n)
|
||||
end
|
||||
for n in kfn.outputNeuronsArray
|
||||
timestep_forward!(n)
|
||||
end
|
||||
|
||||
# pass input_data into input neuron.
|
||||
# number of data point equals to number of input neuron starting from id 1
|
||||
@@ -89,7 +115,6 @@ function (n::lif_neuron)(kfn::knowledgeFn)
|
||||
|
||||
# pulling other neuron's firing status at time t
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t .*= n.subExInType
|
||||
|
||||
if n.refractoryCounter != 0
|
||||
n.refractoryCounter -= 1
|
||||
@@ -102,7 +127,7 @@ function (n::lif_neuron)(kfn::knowledgeFn)
|
||||
# decay of v_t1
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
else
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t .* n.subExInType) # signal from other neuron that this neuron subscribed
|
||||
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
@@ -132,7 +157,6 @@ function (n::alif_neuron)(kfn::knowledgeFn)
|
||||
n.timeStep = kfn.timeStep
|
||||
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t .*= n.subExInType
|
||||
|
||||
if n.refractoryCounter != 0
|
||||
n.refractoryCounter -= 1
|
||||
@@ -149,7 +173,7 @@ function (n::alif_neuron)(kfn::knowledgeFn)
|
||||
n.z_t = isnothing(n.z_t) ? false : n.z_t
|
||||
n.a = (n.rho * n.a) + ((1 - n.rho) * n.z_t)
|
||||
n.av_th = n.v_th + (n.beta * n.a)
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t .* n.subExInType) # signal from other neuron that this neuron subscribed
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
n.v_t1 = no_negative!.(n.v_t1)
|
||||
@@ -181,7 +205,6 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
|
||||
|
||||
# pulling other neuron's firing status at time t
|
||||
n.z_i_t = getindex(kfn.firedNeurons_t0, n.subscriptionList)
|
||||
n.z_i_t .*= n.subExInType
|
||||
|
||||
if n.refractoryCounter != 0
|
||||
n.refractoryCounter -= 1
|
||||
@@ -194,7 +217,7 @@ function (n::linear_neuron)(kfn::T) where T<:knowledgeFn
|
||||
# decay of v_t1
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
else
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t .* n.subExInType) # signal from other neuron that this neuron subscribed
|
||||
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
|
||||
37
src/learn.jl
37
src/learn.jl
@@ -12,10 +12,6 @@ export learn!
|
||||
|
||||
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||
m.knowledgeFn[:I].learningStage = m.learningStage
|
||||
|
||||
# # how many matched respond and correct answer
|
||||
# matched = sum(isequal(modelRespond, correctAnswer))
|
||||
|
||||
if correctAnswer === nothing
|
||||
correctAnswer_I = zeros(length(modelRespond))
|
||||
else
|
||||
@@ -28,39 +24,6 @@ end
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
||||
|
||||
# ΔWeight Conn. Strength
|
||||
# case 1 no no during input signal, no correct answer available, no answer
|
||||
# case 2 no - during input signal, no correct answer available, wrong answer
|
||||
# case 3 + - during input signal, correct answer available, no answer
|
||||
# case 4 no - during input signal, correct answer available, wrong answer
|
||||
# case 5 no ++ during input signal, correct answer
|
||||
# case 6 no ++ after input signal, at correct timing, correct answer
|
||||
# case 6 + - after input signal, at correct timing, no answer
|
||||
# case 9 no -- after input signal, at correct timing, wrong answer
|
||||
# case 7 adjust + after input signal, after correct timing (late), correct answer
|
||||
# case 8 after input signal, after correct timing (late), no answer
|
||||
# case 8 no - after input signal, after correct timing (late), wrong answer
|
||||
|
||||
# success
|
||||
|
||||
if kfn.learningStage == "start_learning"
|
||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||
# don't gets wiped and can be logged for visualization later
|
||||
for n in kfn.neuronsArray
|
||||
# epsilonRec need to be reset because it counting how many each synaptic fires and
|
||||
# use this info to calculate how much synaptic weight should be adjust
|
||||
resetLearningParams!(n)
|
||||
end
|
||||
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.firedNeurons_t0 = Vector{Bool}()
|
||||
kfn.firedNeurons_t1 = Vector{Bool}()
|
||||
|
||||
kfn.learningStage = "learning"
|
||||
end
|
||||
|
||||
# compute kfn error
|
||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, out) in enumerate(outs)
|
||||
|
||||
@@ -19,7 +19,7 @@ function timestep_forward!(x::passthrough_neuron)
|
||||
x.z_t = x.z_t1
|
||||
end
|
||||
|
||||
function timestep_forward!(x::compute_neuron)
|
||||
function timestep_forward!(x::Union{compute_neuron, output_neuron})
|
||||
x.z_t = x.z_t1
|
||||
x.v_t = x.v_t1
|
||||
end
|
||||
|
||||
13
src/types.jl
13
src/types.jl
@@ -310,9 +310,8 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
||||
subExInType::Array{Int64} = Vector{Int64}() # store ExIn type of subscribed neurons
|
||||
timeStep::Number = 0.0 # current time
|
||||
wRec::Union{Array{Float64},Nothing} = nothing # synaptic weight (for receiving signal from other neuron)
|
||||
v_t::Float64 = rand() # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||
v_t::Float64 = 0.0 # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||
v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||
v_t_default::Union{Float64,Nothing} = 0.0 # default membrane potential voltage
|
||||
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
||||
vRest::Float64 = 0.0 # resting potential after neuron fired
|
||||
z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep
|
||||
@@ -403,9 +402,8 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
||||
subExInType::Array{Int64} = Vector{Int64}() # store ExIn type of subscribed neurons
|
||||
timeStep::Union{Number,Nothing} = nothing # current time
|
||||
wRec::Union{Array{Float64},Nothing} = nothing # synaptic weight (for receiving signal from other neuron)
|
||||
v_t::Float64 = rand() # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||
v_t::Float64 = 0.0 # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||
v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||
v_t_default::Union{Float64,Nothing} = 0.0
|
||||
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
||||
vRest::Float64 = 0.0 # resting potential after neuron fired
|
||||
z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep
|
||||
@@ -508,14 +506,13 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
|
||||
knowledgeFnName::Union{String,Nothing} = nothing # knowledgeFn that this neuron belongs to
|
||||
subscriptionList::Union{Array{Int64},Nothing} = nothing # list of other neuron that this neuron synapse subscribed to
|
||||
timeStep::Union{Number,Nothing} = nothing # current time
|
||||
|
||||
subExInType::Array{Int64} = Vector{Int64}() # store ExIn type of subscribed neurons
|
||||
subExInType::Array{Int64} = Vector{Int64}() # store ExIn type of subscribed neurons
|
||||
wRec::Union{Array{Float64},Nothing} = nothing # synaptic weight (for receiving signal from other neuron)
|
||||
v_t::Float64 = 0.0 # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||
v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||
v_t_default::Union{Float64,Nothing} = 0.0 # default membrane potential voltage
|
||||
v_t1::Float64 = rand() # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
||||
vRest::Float64 = 0.0 # resting potential after neuron fired
|
||||
z_t::Bool = false # zᵗ, neuron postsynaptic firing of previous timestep
|
||||
# zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all
|
||||
# neurons forward function at each timestep-by-timestep is to do every neuron
|
||||
# forward calculation. Each neuron requires access to other neuron's firing status
|
||||
|
||||
Reference in New Issue
Block a user