add todo
This commit is contained in:
@@ -35,9 +35,10 @@ using .interface
|
|||||||
"""
|
"""
|
||||||
Todo:
|
Todo:
|
||||||
[*3] no "start learning" use reset learning and "inference", "learning" mode instead
|
[*3] no "start learning" use reset learning and "inference", "learning" mode instead
|
||||||
[4] add time-based learning method. Also implement "thinking period"
|
[6] add time-based learning method. Also implement "thinking period"
|
||||||
[5] verify that model can complete learning cycle with no error
|
[7] verify that model can complete learning cycle with no error
|
||||||
[6] neuroplasticity() with synaptic connection strength concept
|
[4] synaptic connection strength concept
|
||||||
|
[5] neuroplasticity() i.e. change connection
|
||||||
[] using RL to control learning signal
|
[] using RL to control learning signal
|
||||||
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
[] consider using Dates.now() instead of timestamp because time_stamp may overflow
|
||||||
[] training should include adjusting α, neuron membrane potential decay factor
|
[] training should include adjusting α, neuron membrane potential decay factor
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ function (n::lif_neuron)(kfn::knowledgeFn)
|
|||||||
|
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recurrent_signal
|
n.v_t1 = n.alpha_v_t + n.recurrent_signal
|
||||||
|
n.v_t1 = no_negative!.(n.v_t1)
|
||||||
|
|
||||||
if n.v_t1 > n.v_th
|
if n.v_t1 > n.v_th
|
||||||
n.z_t1 = true
|
n.z_t1 = true
|
||||||
@@ -169,6 +170,7 @@ function (n::alif_neuron)(kfn::knowledgeFn)
|
|||||||
n.recurrent_signal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
n.recurrent_signal = sum(n.w_rec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||||
n.alpha_v_t = n.alpha * n.v_t
|
n.alpha_v_t = n.alpha * n.v_t
|
||||||
n.v_t1 = n.alpha_v_t + n.recurrent_signal
|
n.v_t1 = n.alpha_v_t + n.recurrent_signal
|
||||||
|
n.v_t1 = no_negative!.(n.v_t1)
|
||||||
if n.v_t1 > n.av_th
|
if n.v_t1 > n.av_th
|
||||||
n.z_t1 = true
|
n.z_t1 = true
|
||||||
n.refractory_counter = n.refractory_duration
|
n.refractory_counter = n.refractory_duration
|
||||||
|
|||||||
50
src/learn.jl
50
src/learn.jl
@@ -12,14 +12,38 @@ export learn!
|
|||||||
|
|
||||||
function learn!(m::model, model_respond, correct_answer)
|
function learn!(m::model, model_respond, correct_answer)
|
||||||
if m.learning_stage == "learning"
|
if m.learning_stage == "learning"
|
||||||
|
#WORKING compute error
|
||||||
|
if m.time_stamp < m.model_params[:perfect_timing]
|
||||||
|
too_early = m.model_params[:perfect_timing] - m.time_stamp
|
||||||
|
model_error = (model_respond .- correct_answer) * too_early
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model_error = Flux.logitcrossentropy(model_respond, correct_answer)
|
model_error = Flux.logitcrossentropy(model_respond, correct_answer)
|
||||||
output_elements_error = model_respond - correct_answer
|
output_elements_error = model_respond - correct_answer
|
||||||
|
|
||||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||||
|
|
||||||
#WORKING compute error
|
|
||||||
# if m.time_stamp < m.m
|
|
||||||
model_error = model_respond .- correct_answer
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -34,18 +58,18 @@ function learn!(m::model, model_respond, correct_answer)
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
# function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
||||||
if m.learning_stage != "doing_inference"
|
# if m.learning_stage != "doing_inference"
|
||||||
model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
# model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
||||||
output_elements_error = raw_model_respond - correct_answer
|
# output_elements_error = raw_model_respond - correct_answer
|
||||||
|
|
||||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
# learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||||
else
|
# else
|
||||||
model_error = nothing
|
# model_error = nothing
|
||||||
end
|
# end
|
||||||
|
|
||||||
return model_error
|
# return model_error
|
||||||
end
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
module snn_utils
|
module snn_utils
|
||||||
|
|
||||||
using Flux.Optimise: apply!
|
using Flux.Optimise: apply!
|
||||||
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative,
|
export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, no_negative!,
|
||||||
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
precision, calculate_w_change!, store_knowledgefn_error!, interneurons_adjustment!,
|
||||||
reset_z_t!, reset_learning_params!, reset_learning_history_params!,
|
reset_z_t!, reset_learning_params!, reset_learning_history_params!,
|
||||||
cal_v_reg!, calculate_w_change_end!,
|
cal_v_reg!, calculate_w_change_end!,
|
||||||
@@ -28,7 +28,7 @@ function timestep_forward!(x::linear_neuron)
|
|||||||
x.out_t = x.out_t1
|
x.out_t = x.out_t1
|
||||||
end
|
end
|
||||||
|
|
||||||
no_negative(n) = n < 0.0 ? 0.0 : x
|
no_negative!(x) = x < 0.0 ? 0.0 : x
|
||||||
precision(x::Array{<:Array}) = ( std(mean.(x)) / mean(mean.(x)) ) * 100
|
precision(x::Array{<:Array}) = ( std(mean.(x)) / mean(mean.(x)) ) * 100
|
||||||
|
|
||||||
# reset functions for LIF/ALIF neuron
|
# reset functions for LIF/ALIF neuron
|
||||||
|
|||||||
Reference in New Issue
Block a user