implement start learning
This commit is contained in:
1
src/.vscode/settings.json
vendored
Normal file
1
src/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
@@ -34,13 +34,15 @@ using .interface
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Todo:
|
Todo:
|
||||||
[*3] implement "start learning", reset learning and "during_learning", "end_learning and
|
[7] time-based learning method based on new error formula
|
||||||
"inference"
|
if output neuron not activate when it should, use output neuron's
|
||||||
[4] output neuron connect to random multiple compute neurons
|
(vth - vt)*100/vth as error
|
||||||
[7] add time-based learning method.
|
if output neuron activates when it should NOT, use output neuron's
|
||||||
[] implement "thinking period"
|
(vt*100)/vth as error
|
||||||
|
[*4] output neuron connect to random multiple compute neurons and have the same structure
|
||||||
|
as lif
|
||||||
[8] verify that model can complete learning cycle with no error
|
[8] verify that model can complete learning cycle with no error
|
||||||
[5] synaptic connection strength concept
|
[5] synaptic connection strength concept. use sigmoid
|
||||||
[6] neuroplasticity() i.e. change connection
|
[6] neuroplasticity() i.e. change connection
|
||||||
[] using RL to control learning signal
|
[] using RL to control learning signal
|
||||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||||
@@ -50,6 +52,8 @@ using .interface
|
|||||||
[DONE] each knowledgeFn should have its own noise generater
|
[DONE] each knowledgeFn should have its own noise generater
|
||||||
[DONE] where to put pseudo derivative (n.phi)
|
[DONE] where to put pseudo derivative (n.phi)
|
||||||
[DONE] add excitatory, inhabitory to neuron
|
[DONE] add excitatory, inhabitory to neuron
|
||||||
|
[DONE] implement "start learning", reset learning and "learning", "end_learning and
|
||||||
|
"inference"
|
||||||
|
|
||||||
Change from version: v06_36a
|
Change from version: v06_36a
|
||||||
-
|
-
|
||||||
|
|||||||
81
src/learn.jl
81
src/learn.jl
@@ -10,70 +10,37 @@ export learn!
|
|||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
function learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)
|
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||||
|
m.knowledgeFn[:I].learningStage = m.learningStage
|
||||||
|
# ΔWeight Conn. Strength
|
||||||
|
# case 1 no no during input signal, no correct answer available, no answer
|
||||||
|
# case 2 no - during input signal, no correct answer available, wrong answer
|
||||||
|
# case 3 + - during input signal, correct answer available, no answer
|
||||||
|
# case 4 no - during input signal, correct answer available, wrong answer
|
||||||
|
# case 5 no ++ during input signal, correct answer
|
||||||
|
# case 6 no ++ after input signal, at correct timing, correct answer
|
||||||
|
# case 6 + - after input signal, at correct timing, no answer
|
||||||
|
# case 9 no -- after input signal, at correct timing, wrong answer
|
||||||
|
# case 7 adjust + after input signal, after correct timing (late), correct answer
|
||||||
|
# case 8 after input signal, after correct timing (late), no answer
|
||||||
|
# case 8 no - after input signal, after correct timing (late), wrong answer
|
||||||
|
|
||||||
# set all KFN
|
# success
|
||||||
if m.learningStage == "start_learning"
|
|
||||||
m.knowledgeFn[:I].learningStage = "start_learning"
|
|
||||||
elseif m.learningStage == "end_learning"
|
|
||||||
m.knowledgeFn[:I].learningStage = "end_learning"
|
|
||||||
else
|
|
||||||
end
|
|
||||||
|
|
||||||
#WORKING compute error
|
# how many matched respond and correct answer
|
||||||
# timingError =
|
matched = sum(isequal(modelRespond, correctAnswer))
|
||||||
|
|
||||||
|
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||||
|
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||||
|
|
||||||
too_early = m.modelParams[:perfect_timing] - m.timeStep
|
# return model_error
|
||||||
model_error = (model_respond .- correct_answer) * too_early
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model_error = Flux.logitcrossentropy(model_respond, correct_answer)
|
|
||||||
output_elements_error = model_respond - correct_answer
|
|
||||||
|
|
||||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return model_error
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
# function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
|
||||||
# if m.learningStage != "doing_inference"
|
|
||||||
# model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
|
||||||
# output_elements_error = raw_model_respond - correct_answer
|
|
||||||
|
|
||||||
# learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
|
||||||
# else
|
|
||||||
# model_error = nothing
|
|
||||||
# end
|
|
||||||
|
|
||||||
# return model_error
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
""" knowledgeFn learn()
|
""" knowledgeFn learn()
|
||||||
"""
|
"""
|
||||||
function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
|
function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||||
outputError::Union{Vector,Nothing}=nothing)
|
if kfn.learningStage == "start_learning"
|
||||||
kfn.error = error
|
|
||||||
kfn.outputError = outputError
|
|
||||||
|
|
||||||
kfn.learningStage = m.learningStage
|
|
||||||
if m.learningStage == "start_learning"
|
|
||||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||||
# don't gets wiped and can be logged for visualization later
|
# don't gets wiped and can be logged for visualization later
|
||||||
for n in kfn.neuronsArray
|
for n in kfn.neuronsArray
|
||||||
@@ -85,6 +52,10 @@ function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
|
|||||||
# clear variables
|
# clear variables
|
||||||
kfn.firedNeurons = Vector{Int64}()
|
kfn.firedNeurons = Vector{Int64}()
|
||||||
kfn.outputs = nothing
|
kfn.outputs = nothing
|
||||||
|
|
||||||
|
kfn.learningStage = "learning"
|
||||||
|
elseif kfn.learningStage = "end_learning"
|
||||||
|
kfn.learningStage = "inference"
|
||||||
end
|
end
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
# Threads.@threads for n in kfn.neuronsArray
|
||||||
|
|||||||
42
src/types.jl
42
src/types.jl
@@ -106,7 +106,6 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
|||||||
learningStage::String = "inference"
|
learningStage::String = "inference"
|
||||||
|
|
||||||
error::Union{Float64,Nothing} = nothing
|
error::Union{Float64,Nothing} = nothing
|
||||||
outputError::Union{Array,Nothing} = Vector{AbstractFloat}()
|
|
||||||
softreset::Bool = false
|
softreset::Bool = false
|
||||||
|
|
||||||
firedNeurons::Array{Int64} = Vector{Int64}() # store unique id of firing neurons to be used when random neuron connection
|
firedNeurons::Array{Int64} = Vector{Int64}() # store unique id of firing neurons to be used when random neuron connection
|
||||||
@@ -331,7 +330,7 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
|||||||
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
||||||
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
||||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires
|
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
||||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
||||||
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
||||||
refractoryCounter::Integer = 0
|
refractoryCounter::Integer = 0
|
||||||
@@ -340,7 +339,6 @@ Base.@kwdef mutable struct lif_neuron <: compute_neuron
|
|||||||
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated w_rec change
|
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated w_rec change
|
||||||
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
||||||
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
||||||
voltageDropPercentage::Union{Float64,Nothing} = 1.0 # voltage drop as a percentage of v_th
|
|
||||||
error::Union{Float64,Nothing} = nothing # local neuron error
|
error::Union{Float64,Nothing} = nothing # local neuron error
|
||||||
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
|
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
|
||||||
|
|
||||||
@@ -428,7 +426,7 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
|||||||
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
|
eRec::Union{Array{Float64},Nothing} = nothing # neuron's eligibility trace
|
||||||
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
|
eta::Union{Float64,Nothing} = 0.01 # eta, learning rate
|
||||||
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||||
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires
|
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
||||||
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
|
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
|
||||||
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
|
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refractory period in millisecond
|
||||||
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
# refractory_state_active::Union{Bool,Nothing} = false # if true, neuron is in refractory state and cannot process new information
|
||||||
@@ -437,7 +435,6 @@ Base.@kwdef mutable struct alif_neuron <: compute_neuron
|
|||||||
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated w_rec change
|
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated w_rec change
|
||||||
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
||||||
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
||||||
voltageDropPercentage::Union{Float64,Nothing} = 1.0 # voltage drop as a percentage of v_th
|
|
||||||
error::Union{Float64,Nothing} = nothing # local neuron error
|
error::Union{Float64,Nothing} = nothing # local neuron error
|
||||||
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
|
optimiser::Union{Any,Nothing} = load_optimiser("AdaBelief") # Flux optimizer
|
||||||
|
|
||||||
@@ -510,9 +507,42 @@ Base.@kwdef mutable struct linear_neuron <: output_neuron
|
|||||||
knowledgeFnName::Union{String,Nothing} = nothing # knowledgeFn that this neuron belongs to
|
knowledgeFnName::Union{String,Nothing} = nothing # knowledgeFn that this neuron belongs to
|
||||||
subscriptionList::Union{Array{Int64},Nothing} = nothing # list of other neuron that this neuron synapse subscribed to
|
subscriptionList::Union{Array{Int64},Nothing} = nothing # list of other neuron that this neuron synapse subscribed to
|
||||||
timeStep::Union{Number,Nothing} = nothing # current time
|
timeStep::Union{Number,Nothing} = nothing # current time
|
||||||
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
|
||||||
out_t::Bool = false # output of linear neuron BEFORE forward()
|
out_t::Bool = false # output of linear neuron BEFORE forward()
|
||||||
out_t1::Bool = false # output of linear neuron AFTER forward()
|
out_t1::Bool = false # output of linear neuron AFTER forward()
|
||||||
|
#WORKING
|
||||||
|
subExInType::Array{Int64} = Vector{Int64}() # store ExIn type of subscribed neurons
|
||||||
|
w_rec::Union{Array{Float64},Nothing} = nothing # synaptic weight (for receiving signal from other neuron)
|
||||||
|
v_t::Float64 = 0.0 # vᵗ, postsynaptic neuron membrane potential of previous timestep
|
||||||
|
v_t1::Float64 = 0.0 # vᵗ⁺¹, postsynaptic neuron membrane potential at current timestep
|
||||||
|
v_t_default::Union{Float64,Nothing} = 0.0 # default membrane potential voltage
|
||||||
|
v_th::Float64 = 1.0 # vᵗʰ, neuron firing threshold
|
||||||
|
vRest::Float64 = 0.0 # resting potential after neuron fired
|
||||||
|
# zᵗ⁺¹, neuron firing status at time = t+1. I need this because the way I calculate all
|
||||||
|
# neurons forward function at each timestep-by-timestep is to do every neuron
|
||||||
|
# forward calculation. Each neuron requires access to other neuron's firing status
|
||||||
|
# during v_t1 calculation hence I need a variable to hold z_t1 so that I'm not replacing z_t
|
||||||
|
z_t1::Bool = false # neuron postsynaptic firing at current timestep (after neuron's calculation)
|
||||||
|
|
||||||
|
# neuron presynaptic firing at current timestep (which is other neuron postsynaptic firing of
|
||||||
|
# previous timestep)
|
||||||
|
z_i_t::Union{Array{Bool},Nothing} = nothing
|
||||||
|
|
||||||
|
gammaPd::Union{Float64,Nothing} = 0.3 # γ_pd, discount factor, value from paper
|
||||||
|
alpha::Union{Float64,Nothing} = nothing # α, neuron membrane potential decay factor
|
||||||
|
phi::Union{Float64,Nothing} = nothing # ϕ, psuedo derivative
|
||||||
|
epsilonRec::Union{Array{Float64},Nothing} = nothing # ϵ_rec, eligibility vector for neuron spike
|
||||||
|
decayedEpsilonRec::Union{Array{Float64},Nothing} = nothing # α * epsilonRec
|
||||||
|
eRec::Union{Array{Float64},Nothing} = nothing # eligibility trace for neuron spike
|
||||||
|
delta::Union{Float64,Nothing} = 1.0 # δ, discreate timestep size in millisecond
|
||||||
|
lastFiringTime::Union{Float64,Nothing} = 0.0 # the last time neuron fires, use to calculate exponantial decay of v_t1
|
||||||
|
refractoryDuration::Union{Float64,Nothing} = 3 # neuron's refratory period in millisecond
|
||||||
|
refractoryCounter::Integer = 0
|
||||||
|
tau_m::Union{Float64,Nothing} = nothing # τ_m, membrane time constant in millisecond
|
||||||
|
eta::Union{Float64,Nothing} = 0.01 # η, learning rate
|
||||||
|
wRecChange::Union{Array{Float64},Nothing} = nothing # Δw_rec, cumulated w_rec change
|
||||||
|
recSignal::Union{Float64,Nothing} = nothing # incoming recurrent signal
|
||||||
|
alpha_v_t::Union{Float64,Nothing} = nothing # alpha * v_t
|
||||||
|
error::Union{Float64,Nothing} = nothing # local neuron error
|
||||||
end
|
end
|
||||||
|
|
||||||
""" linear neuron outer constructor
|
""" linear neuron outer constructor
|
||||||
|
|||||||
Reference in New Issue
Block a user