implement start learning
This commit is contained in:
83
src/learn.jl
83
src/learn.jl
@@ -10,70 +10,37 @@ export learn!
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
function learn!(m::model, modelRespond, correctAnswer=nothing, correctTiming=nothing)
|
||||
|
||||
# set all KFN
|
||||
if m.learningStage == "start_learning"
|
||||
m.knowledgeFn[:I].learningStage = "start_learning"
|
||||
elseif m.learningStage == "end_learning"
|
||||
m.knowledgeFn[:I].learningStage = "end_learning"
|
||||
else
|
||||
end
|
||||
|
||||
#WORKING compute error
|
||||
# timingError =
|
||||
|
||||
|
||||
too_early = m.modelParams[:perfect_timing] - m.timeStep
|
||||
model_error = (model_respond .- correct_answer) * too_early
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model_error = Flux.logitcrossentropy(model_respond, correct_answer)
|
||||
output_elements_error = model_respond - correct_answer
|
||||
|
||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
|
||||
|
||||
|
||||
function learn!(m::model, modelRespond, correctAnswer=nothing)
|
||||
m.knowledgeFn[:I].learningStage = m.learningStage
|
||||
# ΔWeight Conn. Strength
|
||||
# case 1 no no during input signal, no correct answer available, no answer
|
||||
# case 2 no - during input signal, no correct answer available, wrong answer
|
||||
# case 3 + - during input signal, correct answer available, no answer
|
||||
# case 4 no - during input signal, correct answer available, wrong answer
|
||||
# case 5 no ++ during input signal, correct answer
|
||||
# case 6 no ++ after input signal, at correct timing, correct answer
|
||||
# case 6 + - after input signal, at correct timing, no answer
|
||||
# case 9 no -- after input signal, at correct timing, wrong answer
|
||||
# case 7 adjust + after input signal, after correct timing (late), correct answer
|
||||
# case 8 after input signal, after correct timing (late), no answer
|
||||
# case 8 no - after input signal, after correct timing (late), wrong answer
|
||||
|
||||
# success
|
||||
|
||||
# how many matched respond and correct answer
|
||||
matched = sum(isequal(modelRespond, correctAnswer))
|
||||
|
||||
return model_error
|
||||
correctAnswer_I = correctAnswer # correct answer for kfn I
|
||||
learn!(m.knowledgeFn[:I], correctAnswer_I)
|
||||
|
||||
# return model_error
|
||||
end
|
||||
|
||||
|
||||
# function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
||||
# if m.learningStage != "doing_inference"
|
||||
# model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
||||
# output_elements_error = raw_model_respond - correct_answer
|
||||
|
||||
# learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
# else
|
||||
# model_error = nothing
|
||||
# end
|
||||
|
||||
# return model_error
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
|
||||
outputError::Union{Vector,Nothing}=nothing)
|
||||
kfn.error = error
|
||||
kfn.outputError = outputError
|
||||
|
||||
kfn.learningStage = m.learningStage
|
||||
if m.learningStage == "start_learning"
|
||||
function learn!(kfn::kfn_1, correctAnswer=nothing)
|
||||
if kfn.learningStage == "start_learning"
|
||||
# reset params here instead of at the end_learning so that neuron's parameter data
|
||||
# don't gets wiped and can be logged for visualization later
|
||||
for n in kfn.neuronsArray
|
||||
@@ -85,6 +52,10 @@ function learn!(kfn::knowledgeFn, error::Union{Float64,Nothing}=nothing,
|
||||
# clear variables
|
||||
kfn.firedNeurons = Vector{Int64}()
|
||||
kfn.outputs = nothing
|
||||
|
||||
kfn.learningStage = "learning"
|
||||
elseif kfn.learningStage = "end_learning"
|
||||
kfn.learningStage = "inference"
|
||||
end
|
||||
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
|
||||
Reference in New Issue
Block a user