0.0.2
This commit is contained in:
36
src/learn.jl
36
src/learn.jl
@@ -40,23 +40,27 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
# compute kfn error for each neuron
|
||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, out) in enumerate(outs)
|
||||
if out != correctAnswer[i] # need to adjust weight
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
100.0 / kfn.outputNeuronsArray[i].v_th )
|
||||
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, kfnError)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, correctAnswer[i])
|
||||
end
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort], correctAnswer[i])
|
||||
else # output neuron that is NOT associated with correctAnswer
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort], correctAnswer[i])
|
||||
if out == correctAnswer # output correct
|
||||
kfnError = 0.0
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, kfnError)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, true)
|
||||
end
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort], true)
|
||||
else
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
100.0 / kfn.outputNeuronsArray[i].v_th )^2
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, kfnError)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, false)
|
||||
end
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort], false)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -279,15 +279,15 @@ function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||
|
||||
if updown == "up"
|
||||
if currentStrength > 4 # strong connection
|
||||
updatedStrength = currentStrength + Δstrength
|
||||
updatedStrength = currentStrength + (Δstrength * 0.2)
|
||||
else
|
||||
updatedStrength = currentStrength + (Δstrength * 0.01)
|
||||
updatedStrength = currentStrength + (Δstrength * 0.1)
|
||||
end
|
||||
elseif updown == "down"
|
||||
if currentStrength > 4
|
||||
updatedStrength = currentStrength - (Δstrength * 0.5)
|
||||
updatedStrength = currentStrength - (Δstrength * 0.1)
|
||||
else
|
||||
updatedStrength = currentStrength - Δstrength
|
||||
updatedStrength = currentStrength - (Δstrength * 1.0)
|
||||
end
|
||||
else
|
||||
error("undefined condition line $(@__LINE__)")
|
||||
|
||||
Reference in New Issue
Block a user