diff --git a/src/Ironpen.jl b/src/Ironpen.jl index 2936920..8073fd9 100644 --- a/src/Ironpen.jl +++ b/src/Ironpen.jl @@ -32,10 +32,10 @@ using .learn # using .interface #------------------------------------------------------------------------------------------------100 -""" +""" version 0.0.2 Todo: - - [2] implement connection strength based on right or wrong answer + [*2] implement connection strength based on right or wrong answer + [*1] how to manage how much constrength increase and decrease [4] implement dormant connection [3] Δweight * connection strength [] using RL to control learning signal @@ -64,6 +64,13 @@ using .learn [DONE] add multi threads [DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons [DONE] add maximum weight cap of each connection + [DONE] weaker connection should be harder to increase strength. It requires a lot of + repeat activation to get it stronger. While strong connction requires a lot of + inactivation to get it weaker. The concept is strong connection will lock + correct neural pathway through repeated use of the right connection i.e. keep training + on the correct answer -> strengthen the right neural pathway (connections) -> + this correct neural pathway resist to change. + Not used connection should dissapear (forgetting). Change from version: v06_36a - diff --git a/src/learn.jl b/src/learn.jl index 08d6cbf..b643309 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -8,7 +8,7 @@ export learn! #------------------------------------------------------------------------------------------------100 -function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing}) +function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing}) if correctAnswer === nothing correctAnswer_I = BitArray(zeros(length(modelRespond))) else @@ -21,7 +21,7 @@ end """ knowledgeFn learn() """ function learn!(kfn::kfn_1, correctAnswer::BitVector) - # #TESTING compute kfn error for each neuron + # compute kfn error for each neuron # outs = [n.z_t1 for n in kfn.outputNeuronsArray] # for (i, out) in enumerate(outs) # if out != correctAnswer[i] # need to adjust weight @@ -37,25 +37,25 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector) # end # end - #TESTING compute kfn error for each neuron + # compute kfn error for each neuron outs = [n.z_t1 for n in kfn.outputNeuronsArray] for (i, out) in enumerate(outs) if out != correctAnswer[i] # need to adjust weight - kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 / - kfn.outputNeuronsArray[i].v_th ) + kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * + 100.0 / kfn.outputNeuronsArray[i].v_th ) if correctAnswer[i] == 1 # output neuron that associated with correctAnswer Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error # for n in kfn.neuronsArray compute_wRecChange!(n, kfnError) - learn!(n, kfn.firedNeurons, kfn.nExInType) + learn!(n, kfn.firedNeurons, kfn.nExInType, correctAnswer[i]) end compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError) learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType, - kfn.kfnParams[:totalInputPort]) + kfn.kfnParams[:totalInputPort], correctAnswer[i]) else # output neuron that is NOT associated with correctAnswer compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError) learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType, - kfn.kfnParams[:totalInputPort]) + kfn.kfnParams[:totalInputPort], correctAnswer[i]) end end end @@ -94,12 +94,14 @@ function compute_wRecChange!(n::linearNeuron, error::Float64) reset_epsilonRec!(n) end -function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron +function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron # skip end -function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron +function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron wSign_0 = sign.(n.wRec) # original sign + #TESTING strong connection gets less weight change, weak connection gets more weight change + n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) n.wRec += n.wRecChange # merge wRecChange into wRec reset_wRecChange!(n) wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign @@ -110,12 +112,14 @@ function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron n.wRec .*= nonFlipedSign capMaxWeight!(n.wRec) # cap maximum weight - synapticConnStrength!(n) + synapticConnStrength!(n, correctAnswer) neuroplasticity!(n, firedNeurons, nExInType) end -function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron +function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron wSign_0 = sign.(n.wRec) # original sign + #TESTING strong connection gets less weight change, weak connection gets more weight change + n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength)) n.wRec += n.wRecChange reset_wRecChange!(n) wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign @@ -126,7 +130,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNe n.wRec .*= nonFlipedSign capMaxWeight!(n.wRec) # cap maximum weight - synapticConnStrength!(n) + synapticConnStrength!(n, correctAnswer) neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) end diff --git a/src/snn_utils.jl b/src/snn_utils.jl index 8a1b549..eb1e76e 100644 --- a/src/snn_utils.jl +++ b/src/snn_utils.jl @@ -6,7 +6,7 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron, reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!, firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!, neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!, - gradient_withloss, capMaxWeight! + gradient_withloss, capMaxWeight!, connStrengthAdjust using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux using GeneralUtils @@ -257,46 +257,111 @@ function connStrengthAdjust(currentStrength::Float64) return Δstrength::Float64 end -""" compute synaptic connection strength. bias will shift currentStrength to fit into +""" Compute synaptic connection strength. bias will shift currentStrength to fit into sigmoid operating range which centred at 0 and range is -37 to 37. + # Example synaptic strength range is 0 to 10 one may use bias = -5 to transform synaptic strength into range -5 to 5 - the return value is shifted back to original scale + the return value is shifted back to original scale. + + # Concept + weaker connection should be harder to increase strength. It requires a lot of + repeat activation to get it stronger. While strong connction requires a lot of + inactivation to get it weaker. The concept is strong connection will lock + correct neural pathway through repeated use of the right connection i.e. keep training + on the correct answer -> strengthen the right neural pathway (connections) -> + this correct neural pathway resist to change. + Not used connection should dissapear (forgetting). """ function synapticConnStrength(currentStrength::Float64, updown::String) Δstrength = connStrengthAdjust(currentStrength) if updown == "up" - updatedStrength = currentStrength + Δstrength + if currentStrength > 4 # strong connection + updatedStrength = currentStrength + Δstrength + else + updatedStrength = currentStrength + (Δstrength * 0.01) + end + elseif updown == "down" + if currentStrength > 4 + updatedStrength = currentStrength - (Δstrength * 0.5) + else + updatedStrength = currentStrength - Δstrength + end else - updatedStrength = currentStrength - Δstrength + error("undefined condition line $(@__LINE__)") end return updatedStrength::Float64 end +# function synapticConnStrength(currentStrength::Float64, updown::String) +# Δstrength = connStrengthAdjust(currentStrength) + +# if updown == "up" +# updatedStrength = currentStrength + Δstrength +# else +# updatedStrength = currentStrength - Δstrength +# end +# return updatedStrength::Float64 +# end """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes below lowerlimit. """ -function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) - for (i, connStrength) in enumerate(n.synapticStrength) - # check whether connStrength increase or decrease based on usage from n.epsilonRec - """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange - calculation in learn!() will reset epsilonRec to zeroes vector in case where - output neuron fires and trigger learn!() just before this synapticConnStrength - calculation. - Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is - ok to use. n.z_i_t_commulative also span across a training sample without resetting. - """ - updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" # - updatedConnStrength = synapticConnStrength(connStrength, updown) - updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, - n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) - # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn - if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] - n.wRec[i] = 0.0 +# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) +# for (i, connStrength) in enumerate(n.synapticStrength) +# # check whether connStrength increase or decrease based on usage from n.epsilonRec +# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange +# calculation in learn!() will reset epsilonRec to zeroes vector in case where +# output neuron fires and trigger learn!() just before this synapticConnStrength +# calculation. +# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is +# ok to use. n.z_i_t_commulative also span across a training sample without resetting. +# """ +# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" # +# updatedConnStrength = synapticConnStrength(connStrength, updown) +# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, +# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) +# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn +# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] +# n.wRec[i] = 0.0 +# end +# n.synapticStrength[i] = updatedConnStrength +# end +# end + +function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool) + if correctAnswer == true + for (i, connStrength) in enumerate(n.synapticStrength) + # check whether connStrength increase or decrease based on usage from n.epsilonRec + """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange + calculation in learn!() will reset epsilonRec to zeroes vector in case where + output neuron fires and trigger learn!() just before this synapticConnStrength + calculation. + Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is + ok to use. n.z_i_t_commulative also span across a training sample without resetting. + """ + updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" + updatedConnStrength = synapticConnStrength(connStrength, updown) + updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, + n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) + # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn + if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] + n.wRec[i] = 0.0 + end + n.synapticStrength[i] = updatedConnStrength + end + else + for (i, connStrength) in enumerate(n.synapticStrength) + updatedConnStrength = synapticConnStrength(connStrength, "down") + updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, + n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) + # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn + if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1] + n.wRec[i] = 0.0 + end + n.synapticStrength[i] = updatedConnStrength end - n.synapticStrength[i] = updatedConnStrength end end