add todo
This commit is contained in:
50
src/learn.jl
50
src/learn.jl
@@ -12,14 +12,38 @@ export learn!
|
||||
|
||||
function learn!(m::model, model_respond, correct_answer)
|
||||
if m.learning_stage == "learning"
|
||||
#WORKING compute error
|
||||
if m.time_stamp < m.model_params[:perfect_timing]
|
||||
too_early = m.model_params[:perfect_timing] - m.time_stamp
|
||||
model_error = (model_respond .- correct_answer) * too_early
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model_error = Flux.logitcrossentropy(model_respond, correct_answer)
|
||||
output_elements_error = model_respond - correct_answer
|
||||
|
||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
|
||||
#WORKING compute error
|
||||
# if m.time_stamp < m.m
|
||||
model_error = model_respond .- correct_answer
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -34,18 +58,18 @@ function learn!(m::model, model_respond, correct_answer)
|
||||
end
|
||||
|
||||
|
||||
function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
||||
if m.learning_stage != "doing_inference"
|
||||
model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
||||
output_elements_error = raw_model_respond - correct_answer
|
||||
# function learn!(m::model, raw_model_respond, correct_answer=nothing)
|
||||
# if m.learning_stage != "doing_inference"
|
||||
# model_error = Flux.logitcrossentropy(raw_model_respond, correct_answer)
|
||||
# output_elements_error = raw_model_respond - correct_answer
|
||||
|
||||
learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
else
|
||||
model_error = nothing
|
||||
end
|
||||
# learn!(m.knowledgeFn[:I], model_error, output_elements_error)
|
||||
# else
|
||||
# model_error = nothing
|
||||
# end
|
||||
|
||||
return model_error
|
||||
end
|
||||
# return model_error
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user