refractoring

This commit is contained in:
2023-05-17 14:28:43 +07:00
parent df26a01929
commit 214466d9e9
5 changed files with 36 additions and 53 deletions

View File

@@ -12,10 +12,6 @@ export learn!
function learn!(m::model, modelRespond, correctAnswer=nothing)
m.knowledgeFn[:I].learningStage = m.learningStage
# # how many matched respond and correct answer
# matched = sum(isequal(modelRespond, correctAnswer))
if correctAnswer === nothing
correctAnswer_I = zeros(length(modelRespond))
else
@@ -28,39 +24,6 @@ end
""" knowledgeFn learn()
"""
function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
# ΔWeight Conn. Strength
# case 1 no no during input signal, no correct answer available, no answer
# case 2 no - during input signal, no correct answer available, wrong answer
# case 3 + - during input signal, correct answer available, no answer
# case 4 no - during input signal, correct answer available, wrong answer
# case 5 no ++ during input signal, correct answer
# case 6 no ++ after input signal, at correct timing, correct answer
# case 6 + - after input signal, at correct timing, no answer
# case 9 no -- after input signal, at correct timing, wrong answer
# case 7 adjust + after input signal, after correct timing (late), correct answer
# case 8 after input signal, after correct timing (late), no answer
# case 8 no - after input signal, after correct timing (late), wrong answer
# success
if kfn.learningStage == "start_learning"
# reset params here instead of at the end_learning so that neuron's parameter data
# don't gets wiped and can be logged for visualization later
for n in kfn.neuronsArray
# epsilonRec need to be reset because it counting how many each synaptic fires and
# use this info to calculate how much synaptic weight should be adjust
resetLearningParams!(n)
end
# clear variables
kfn.firedNeurons = Vector{Int64}()
kfn.firedNeurons_t0 = Vector{Bool}()
kfn.firedNeurons_t1 = Vector{Bool}()
kfn.learningStage = "learning"
end
# compute kfn error
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs)