diff --git a/src/Ironpen.jl b/src/Ironpen.jl index 0f4bf33..0ddd4aa 100644 --- a/src/Ironpen.jl +++ b/src/Ironpen.jl @@ -34,17 +34,20 @@ using .learn """ version 0.0.3 Todo: - [*2] implement connection strength based on right or wrong answer - [*1] how to manage how much constrength increase and decrease [4] implement dormant connection - [3] Δweight * connection strength [] using RL to control learning signal [] consider using Dates.now() instead of timestamp because time_stamp may overflow [5] training should include adjusting α, neuron membrane potential decay factor which defined by neuron.tau_m formula in type.jl Change from version: 0.0.2 - - + [DONE] new learning method + - use Flux.logitcrossentropy for overall error + - remove ΔwRecChange that apply immediately during online learning + - collect ΔwRecChange during online learning (0-784th) and merge with wRec at + the end learning (1000th). + - compute model error at the end learning. Model error times with 5 constant for + higher learning impact than the error during online All features - multidispatch + for loop as main compute method @@ -86,6 +89,9 @@ using .learn on the correct answer -> strengthen the right neural pathway (connections) -> this correct neural pathway resist to change. Not used connection should dissapear (forgetting). + + Removed features + - Δweight * connection strength """ diff --git a/src/learn.jl b/src/learn.jl index 8010c73..dc08433 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -1,67 +1,27 @@ module learn -using Statistics, Random, LinearAlgebra, JSON3 +using Statistics, Random, LinearAlgebra, JSON3, Flux using GeneralUtils using ..types, ..snn_utils -export learn! +export learn!, compute_wRecChange!, computeModelError #------------------------------------------------------------------------------------------------100 -function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing}) - if correctAnswer === nothing - correctAnswer_I = BitArray(zeros(length(modelRespond))) - else - correctAnswer_I = Bool.(correctAnswer) # correct answer for kfn I - end - - learn!(m.knowledgeFn[:I], correctAnswer_I) +function learn!(m::model) + learn!(m.knowledgeFn[:I]) end """ knowledgeFn learn() """ -function learn!(kfn::kfn_1, correctAnswer::BitVector) - # compute kfn error for each neuron - # outs = [n.z_t1 for n in kfn.outputNeuronsArray] - # for (i, out) in enumerate(outs) - # if out != correctAnswer[i] # need to adjust weight - # kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * - # 100 / kfn.outputNeuronsArray[i].v_th ) - - # Threads.@threads for n in kfn.neuronsArray - # # for n in kfn.neuronsArray - # learn!(n, kfnError) - # end - - # learn!(kfn.outputNeuronsArray[i], kfnError) - # end - # end - +function learn!(kfn::kfn_1) # compute kfn error for each neuron - outs = [n.z_t1 for n in kfn.outputNeuronsArray] - for (i, out) in enumerate(outs) - if out == correctAnswer # output correct - kfnError = 0.0 - Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error - # for n in kfn.neuronsArray - compute_wRecChange!(n, kfnError) - learn!(n, kfn.firedNeurons, kfn.nExInType, true) - end - compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError) - learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType, - kfn.kfnParams[:totalInputPort], true) - else - kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * - 100.0 / kfn.outputNeuronsArray[i].v_th )^2 - Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error - # for n in kfn.neuronsArray - compute_wRecChange!(n, kfnError) - learn!(n, kfn.firedNeurons, kfn.nExInType, false) - end - compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError) - learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType, - kfn.kfnParams[:totalInputPort], false) + Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error + # for n in kfn.neuronsArray + learn!(n, kfn.firedNeurons, kfn.nExInType) end + for n in kfn.outputNeuronsArray + learn!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort]) end # wrap up learning session @@ -70,6 +30,30 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector) end end +function computeModelError(modelRespond, correctAnswer; magnitude::Float64=1.0) + if correctAnswer === nothing + correctAnswer = BitArray(zeros(length(modelRespond))) + else + correctAnswer = Bool.(correctAnswer) # correct answer for kfn I + end + return Flux.logitcrossentropy(modelRespond, correctAnswer) .* magnitude +end + +function compute_wRecChange!(m::model, error::Float64) + compute_wRecChange!(m.knowledgeFn[:I], error) +end + +function compute_wRecChange!(kfn::kfn_1, error::Float64) + # compute kfn error for each neuron + Threads.@threads for n in kfn.neuronsArray + # for n in kfn.neuronsArray + compute_wRecChange!(n, error) + end + for n in kfn.outputNeuronsArray + compute_wRecChange!(n, error) + end +end + function compute_wRecChange!(n::passthroughNeuron, error::Float64) # skip end @@ -98,14 +82,12 @@ function compute_wRecChange!(n::linearNeuron, error::Float64) reset_epsilonRec!(n) end -function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron +function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron # skip end -function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron +function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron wSign_0 = sign.(n.wRec) # original sign - #TESTING strong connection gets less weight change, weak connection gets more weight change - n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) n.wRec += n.wRecChange # merge wRecChange into wRec reset_wRecChange!(n) wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign @@ -116,14 +98,12 @@ function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNe n.wRec .*= nonFlipedSign capMaxWeight!(n.wRec) # cap maximum weight - synapticConnStrength!(n, correctAnswer) + synapticConnStrength!(n) neuroplasticity!(n, firedNeurons, nExInType) end -function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron +function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron wSign_0 = sign.(n.wRec) # original sign - #TESTING strong connection gets less weight change, weak connection gets more weight change - n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) n.wRec += n.wRecChange reset_wRecChange!(n) wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign @@ -134,7 +114,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) wh n.wRec .*= nonFlipedSign capMaxWeight!(n.wRec) # cap maximum weight - synapticConnStrength!(n, correctAnswer) + synapticConnStrength!(n) neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) end diff --git a/src/snn_utils.jl b/src/snn_utils.jl index ff7f892..f5d77e4 100644 --- a/src/snn_utils.jl +++ b/src/snn_utils.jl @@ -279,13 +279,13 @@ function synapticConnStrength(currentStrength::Float64, updown::String) if updown == "up" if currentStrength > 4 # strong connection - updatedStrength = currentStrength + (Δstrength * 0.2) + updatedStrength = currentStrength + (Δstrength * 1.0) else - updatedStrength = currentStrength + (Δstrength * 0.1) + updatedStrength = currentStrength + (Δstrength * 1.0) end elseif updown == "down" if currentStrength > 4 - updatedStrength = currentStrength - (Δstrength * 0.1) + updatedStrength = currentStrength - (Δstrength * 1.0) else updatedStrength = currentStrength - (Δstrength * 1.0) end @@ -294,74 +294,29 @@ function synapticConnStrength(currentStrength::Float64, updown::String) end return updatedStrength::Float64 end -# function synapticConnStrength(currentStrength::Float64, updown::String) -# Δstrength = connStrengthAdjust(currentStrength) - -# if updown == "up" -# updatedStrength = currentStrength + Δstrength -# else -# updatedStrength = currentStrength - Δstrength -# end -# return updatedStrength::Float64 -# end """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes below lowerlimit. """ -# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) -# for (i, connStrength) in enumerate(n.synapticStrength) -# # check whether connStrength increase or decrease based on usage from n.epsilonRec -# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange -# calculation in learn!() will reset epsilonRec to zeroes vector in case where -# output neuron fires and trigger learn!() just before this synapticConnStrength -# calculation. -# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is -# ok to use. n.z_i_t_commulative also span across a training sample without resetting. -# """ -# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" # -# updatedConnStrength = synapticConnStrength(connStrength, updown) -# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, -# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) -# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn -# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] -# n.wRec[i] = 0.0 -# end -# n.synapticStrength[i] = updatedConnStrength -# end -# end - -function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool) - if correctAnswer == true - for (i, connStrength) in enumerate(n.synapticStrength) - # check whether connStrength increase or decrease based on usage from n.epsilonRec - """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange - calculation in learn!() will reset epsilonRec to zeroes vector in case where - output neuron fires and trigger learn!() just before this synapticConnStrength - calculation. - Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is - ok to use. n.z_i_t_commulative also span across a training sample without resetting. - """ - updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" - updatedConnStrength = synapticConnStrength(connStrength, updown) - updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, - n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) - # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn - if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] - n.wRec[i] = 0.0 - end - n.synapticStrength[i] = updatedConnStrength - end - else - for (i, connStrength) in enumerate(n.synapticStrength) - updatedConnStrength = synapticConnStrength(connStrength, "down") - updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, - n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) - # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn - if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] - n.wRec[i] = 0.0 - end - n.synapticStrength[i] = updatedConnStrength +function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) + for (i, connStrength) in enumerate(n.synapticStrength) + # check whether connStrength increase or decrease based on usage from n.epsilonRec + """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange + calculation in learn!() will reset epsilonRec to zeroes vector in case where + output neuron fires and trigger learn!() just before this synapticConnStrength + calculation. + Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is + ok to use. n.z_i_t_commulative also span across a training sample without resetting. + """ + updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" + updatedConnStrength = synapticConnStrength(connStrength, updown) + updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, + n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) + # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn + if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] + n.wRec[i] = 0.0 end + n.synapticStrength[i] = updatedConnStrength end end