refactoring

This commit is contained in:
2023-05-10 20:38:23 +07:00
commit 7c4a0dfa6f
15 changed files with 3195 additions and 0 deletions

304
src/learn.jl Normal file
View File

@@ -0,0 +1,304 @@
module learn
using Flux.Optimise: apply!
using Statistics, Flux, Random, LinearAlgebra
using GeneralUtils
using ..types
export learn!
#------------------------------------------------------------------------------------------------100
"""
    learn!(m::model, model_respond, correct_answer)

Compute the model error for one step and, when `m.learning_stage == "learning"`,
push the scalar loss and the per-element output error into the `:I` knowledge
function via `learn!(m.knowledgeFn[:I], …)`. Returns the element-wise error
(`model_respond .- correct_answer`) while learning, `nothing` otherwise.
"""
function learn!(m::model, model_respond, correct_answer)
    # NOTE(review): the method defined just below,
    # `learn!(m::model, raw_model_respond, correct_answer=nothing)`, creates a
    # 3-argument method with the same signature (model, Any, Any); Julia replaces
    # this definition with that one, so this method appears to be dead code.
    # Confirm which of the two bodies is intended.
    if m.learning_stage == "learning"
        # Scalar loss that drives learning inside the knowledge function.
        model_error = Flux.logitcrossentropy(model_respond, correct_answer)
        # Per-output-element error signal handed to the output neurons.
        output_elements_error = model_respond - correct_answer
        learn!(m.knowledgeFn[:I], model_error, output_elements_error)
        #WORKING compute error
        # if m.time_stamp < m.m
        # Overwrites the scalar loss: the caller receives the element-wise error.
        model_error = model_respond .- correct_answer
    else
        # Not in the "learning" stage: signal that no error is available.
        model_error = nothing
    end
    return model_error
end
"""
    learn!(m::model, raw_model_respond, correct_answer=nothing)

Compute the scalar loss for one step and push it, together with the per-element
output error, into the `:I` knowledge function — unless the model is in the
`"doing_inference"` stage, in which case `nothing` is returned and no learning
happens.
"""
function learn!(m::model, raw_model_respond, correct_answer=nothing)
    # NOTE(review): the default argument makes this definition create BOTH a
    # 2-arg and a 3-arg method; the 3-arg method silently replaces the earlier
    # `learn!(m::model, model_respond, correct_answer)` above. Also, calling the
    # 2-arg form while `m.learning_stage != "doing_inference"` passes `nothing`
    # into `Flux.logitcrossentropy`, which will error — confirm intended usage.
    if m.learning_stage != "doing_inference"
        model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
        output_elements_error = raw_model_respond - correct_answer
        learn!(m.knowledgeFn[:I], model_error, output_elements_error)
    else
        # Inference only: no error signal is produced.
        model_error = nothing
    end
    return model_error
end
""" knowledgeFn learn()
"""
function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
output_error::Union{Vector,Nothing}=nothing)
kfn.error = error
kfn.output_error = output_error
# Threads.@threads for n in kfn.neurons_array
for n in kfn.neurons_array
learn!(n, kfn) # Neurons are always learning, besides error from model output
end
if kfn.output_error !== nothing
# Threads.@threads for n in kfn.output_neurons_array
for n in kfn.output_neurons_array # not use multithreading because 1st output neuron
# will set learning rate that will be used by
# other output neurons
learn!(n, kfn)
end
#TODO: put other KFN to learn here
# for main loop user's display and training's exit condition
avg_neurons_firing_rate = 0.0
for n in kfn.neurons_array
if typeof(n) <: compute_neuron
avg_neurons_firing_rate += n.firing_rate
end
end
kfn.avg_neurons_firing_rate = avg_neurons_firing_rate /
kfn.kfn_params[:compute_neuron_number]
avg_neurons_v_t1 = 0.0
for n in kfn.neurons_array
if typeof(n) <: compute_neuron
avg_neurons_v_t1 += n.v_t1
end
end
kfn.avg_neurons_v_t1 = avg_neurons_v_t1 / kfn.kfn_params[:compute_neuron_number]
end
end
""" passthrough_neuron learn()
"""
function learn!(n::passthrough_neuron, kfn::knowledgeFn)
# skip
end
""" lif learn()
"""
function learn!(n::lif_neuron, kfn::knowledgeFn)
if n.learnable_flag == true
n.decayed_epsilon_rec = n.alpha * n.epsilon_rec
n.epsilon_rec = n.decayed_epsilon_rec + n.z_i_t
n.e_rec = n.phi * n.epsilon_rec
end
# a piece of knowledgeFn error that belongs to this neuron
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
n.learning_stage = kfn.learning_stage
# accumulate voltage regularization terms
Snn_utils.cal_v_reg!(n)
if n.learning_stage == "doing_inference"
# no learning
elseif n.learning_stage == "start_learning" ||
n.learning_stage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.w_rec_change = n.w_rec_change +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learning_stage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.w_rec_change = n.w_rec_change +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learning_stage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.w_rec_change = n.w_rec_change +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in w_rec_change update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.w_rec_change)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firing_neurons_list)
end
end
""" alif_neuron learn()
"""
function learn!(n::alif_neuron, kfn::knowledgeFn)
n.decayed_epsilon_rec = n.alpha * n.epsilon_rec
n.epsilon_rec = n.decayed_epsilon_rec + n.z_i_t
n.epsilon_rec_a = (n.phi * n.epsilon_rec) +
((n.rho - (n.phi * n.beta)) * n.epsilon_rec_a)
n.e_rec_v = n.phi * n.epsilon_rec
n.e_rec_a = -n.phi * n.beta * n.epsilon_rec_a
n.e_rec = n.e_rec_v + n.e_rec_a
# a piece of knowledgeFn error that belongs to this neuron
n.error = isnothing(kfn.error) ? nothing : kfn.error * n.Bn
n.learning_stage = kfn.learning_stage
if n.learning_stage == "doing_inference"
# no learning
elseif n.learning_stage == "start_learning" ||
n.learning_stage == "start_learning_no_wchange_reset"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.w_rec_change = n.w_rec_change +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learning_stage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.w_rec_change = n.w_rec_change +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
elseif n.learning_stage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing
Snn_utils.firing_rate!(n)
Snn_utils.firing_diff!(n)
n.w_rec_change = n.w_rec_change +
-apply!(n.optimiser, n.w_rec,
(n.error + Snn_utils.voltage_error!(n) + n.firing_rate_error) * n.e_rec) +
-Snn_utils.firing_rate_regulator!(n) +
-Snn_utils.voltage_regulator!(n)
end
not_zero = (!iszero).(n.w_rec)
# set 0 in w_rec_change update according to 0 in w_rec for hard constrain connection
n.w_rec = n.w_rec + (not_zero .* n.w_rec_change)
replace!(x -> x < 0 ? 0 : x, n.w_rec) # no negative weight
Snn_utils.neuroplasticity!(n, kfn.firing_neurons_list)
end
end
""" linear_neuron learn()
"""
function learn!(n::linear_neuron, kfn::knowledgeFn)
n.error = kfn.output_error[n.id]
n.learning_stage = kfn.learning_stage
if n.learning_stage == "doing_inference"
# no learning
elseif n.learning_stage == "start_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.output_neurons_array[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learning_stage == "during_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.output_neurons_array[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
elseif n.learning_stage == "end_learning"
# if error signal available then accumulates Δw
if n.error !== nothing && n.id == 1 # NOT working w/ multithreading training
Δw = -apply!(n.optimiser, n.w_out, (n.error * n.epsilon_j))
n.w_out_change = n.w_out_change + Δw
n.eta = n.optimiser.eta
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
elseif n.error !== nothing && n.id !== 1
n.eta = kfn.output_neurons_array[1].eta
Δw = -n.eta * n.error * n.epsilon_j
n.w_out_change = n.w_out_change + Δw
Δb = -n.eta * n.error
n.b_change = n.b_change + Δb
end
n.w_out = n.w_out + n.w_out_change
n.b = n.b + n.b_change
end
end
end # module end