refractoring
This commit is contained in:
304
src/learn.jl
Normal file
304
src/learn.jl
Normal file
@@ -0,0 +1,304 @@
|
||||
module learn
|
||||
|
||||
using Flux.Optimise: apply!
|
||||
|
||||
using Statistics, Flux, Random, LinearAlgebra
|
||||
using GeneralUtils
|
||||
using ..types
|
||||
|
||||
export learn!
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
function learn!(m::model, model_respond, correct_answer)
|
||||
if m.learning_stage == "learning"
|
||||
model_error = Flux.logitcrossentropy(model_respond, correct_answer)
|
||||
output_elements_error = model_respond - correct_answer
|
||||
|
||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
|
||||
#WORKING compute error
|
||||
# if m.time_stamp < m.m
|
||||
model_error = model_respond .- correct_answer
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
else
|
||||
model_error = nothing
|
||||
end
|
||||
|
||||
return model_error
|
||||
end
|
||||
|
||||
|
||||
function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
||||
if m.learning_stage != "doing_inference"
|
||||
model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
||||
output_elements_error = raw_model_respond - correct_answer
|
||||
|
||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
else
|
||||
model_error = nothing
|
||||
end
|
||||
|
||||
return model_error
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
|
||||
output_error::Union{Vector,Nothing}=nothing)
|
||||
kfn.error = error
|
||||
kfn.output_error = output_error
|
||||
|
||||
# Threads.@threads for n in kfn.neurons_array
|
||||
for n in kfn.neurons_array
|
||||
learn!(n, kfn) # Neurons are always learning, besides error from model output
|
||||
end
|
||||
|
||||
if kfn.output_error !== nothing
|
||||
# Threads.@threads for n in kfn.output_neurons_array
|
||||
for n in kfn.output_neurons_array # not use multithreading because 1st output neuron
|
||||
# will set learning rate that will be used by
|
||||
# other output neurons
|
||||
learn!(n, kfn)
|
||||
end
|
||||
#TODO: put other KFN to learn here
|
||||
|
||||
# for main loop user's display and training's exit condition
|
||||
avg_neurons_firing_rate = 0.0
|
||||
for n in kfn.neurons_array
|
||||
if typeof(n) <: compute_neuron
|
||||
avg_neurons_firing_rate += n.firing_rate
|
||||
end
|
||||
end
|
||||
kfn.avg_neurons_firing_rate = avg_neurons_firing_rate /
|
||||
kfn.kfn_params[:compute_neuron_number]
|
||||
avg_neurons_v_t1 = 0.0
|
||||
for n in kfn.neurons_array
|
||||
if typeof(n) <: compute_neuron
|
||||
avg_neurons_v_t1 += n.v_t1
|
||||
end
|
||||
end
|
||||
kfn.avg_neurons_v_t1 = avg_neurons_v_t1 / kfn.kfn_params[:compute_neuron_number]
|
||||
end
|
||||
end
|
||||
|
||||
""" passthrough_neuron learn()
|
||||
"""
|
||||
function learn!(n::passthrough_neuron, kfn::knowledgeFn)
|
||||
# skip
|
||||
end
|
||||
|
||||
""" lif learn()
|
||||
"""
|
||||
function learn!(n::lif_neuron, kfn::knowledgeFn)
|
||||
if n.learnable_flag == true
|
||||
|
||||
n.decayed_epsilon_rec = n.alpha * n.epsilon_rec
|
||||
n.epsilon_rec = n.decayed_epsilon_rec + n.z_i_t
|
||||
n.e_rec = n.phi * n.epsilon_rec
|
||||
end
|
||||
|
||||
# a piece of knowledgeFn error that belongs to this neuron
|
||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
||||
n.learning_stage = kfn.learning_stage
|
||||
|
||||
# accumulate voltage regularization terms
|
||||
Snn_utils.cal_v_reg!(n)
|
||||
|
||||
if n.learning_stage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learning_stage == "start_learning" ||
|
||||
n.learning_stage == "start_learning_no_wchange_reset"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.w_rec_change = n.w_rec_change +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learning_stage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.w_rec_change = n.w_rec_change +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learning_stage == "end_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.w_rec_change = n.w_rec_change +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
|
||||
not_zero = (!iszero).(n.w_rec)
|
||||
# set 0 in w_rec_change update according to 0 in w_rec for hard constrain connection
|
||||
n.w_rec = n.w_rec + (not_zero .* n.w_rec_change)
|
||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
||||
|
||||
Snn_utils.neuroplasticity!(n, kfn.firing_neurons_list)
|
||||
end
|
||||
end
|
||||
|
||||
""" alif_neuron learn()
|
||||
"""
|
||||
function learn!(n::alif_neuron, kfn::knowledgeFn)
|
||||
n.decayed_epsilon_rec = n.alpha * n.epsilon_rec
|
||||
n.epsilon_rec = n.decayed_epsilon_rec + n.z_i_t
|
||||
n.epsilon_rec_a = (n.phi * n.epsilon_rec) +
|
||||
((n.rho - (n.phi * n.beta)) * n.epsilon_rec_a)
|
||||
n.e_rec_v = n.phi * n.epsilon_rec
|
||||
n.e_rec_a = -n.phi * n.beta * n.epsilon_rec_a
|
||||
n.e_rec = n.e_rec_v + n.e_rec_a
|
||||
|
||||
# a piece of knowledgeFn error that belongs to this neuron
|
||||
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
|
||||
n.learning_stage = kfn.learning_stage
|
||||
|
||||
|
||||
|
||||
if n.learning_stage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learning_stage == "start_learning" ||
|
||||
n.learning_stage == "start_learning_no_wchange_reset"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.w_rec_change = n.w_rec_change +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learning_stage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.w_rec_change = n.w_rec_change +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
elseif n.learning_stage == "end_learning"
|
||||
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing
|
||||
Snn_utils.firing_rate!(n)
|
||||
Snn_utils.firing_diff!(n)
|
||||
n.w_rec_change = n.w_rec_change +
|
||||
-apply!(n.optimiser, n.w_rec,
|
||||
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
|
||||
-Snn_utils.firing_rate_regulator!(n) +
|
||||
-Snn_utils.voltage_regulator!(n)
|
||||
end
|
||||
|
||||
not_zero = (!iszero).(n.w_rec)
|
||||
# set 0 in w_rec_change update according to 0 in w_rec for hard constrain connection
|
||||
n.w_rec = n.w_rec + (not_zero .* n.w_rec_change)
|
||||
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
|
||||
|
||||
Snn_utils.neuroplasticity!(n, kfn.firing_neurons_list)
|
||||
end
|
||||
end
|
||||
|
||||
""" linear_neuron learn()
|
||||
"""
|
||||
function learn!(n::linear_neuron, kfn::knowledgeFn)
|
||||
n.error = kfn.output_error[n.id]
|
||||
n.learning_stage = kfn.learning_stage
|
||||
|
||||
if n.learning_stage == "doing_inference"
|
||||
# no learning
|
||||
elseif n.learning_stage == "start_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.output_neurons_array[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
elseif n.learning_stage == "during_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.output_neurons_array[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
elseif n.learning_stage == "end_learning"
|
||||
# if error signal available then accumulates Δw
|
||||
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
|
||||
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
n.eta = n.optimiser.eta
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
elseif n.error !== nothing && n.id !== 1
|
||||
n.eta = kfn.output_neurons_array[1].eta
|
||||
Δw = -n.eta * n.error * n.epsilon_j
|
||||
n.w_out_change = n.w_out_change + Δw
|
||||
Δb = -n.eta * n.error
|
||||
n.b_change = n.b_change + Δb
|
||||
end
|
||||
|
||||
n.w_out = n.w_out + n.w_out_change
|
||||
n.b = n.b + n.b_change
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module end
|
||||
Reference in New Issue
Block a user