experimenting compute neuron association

This commit is contained in:
2023-05-27 20:20:37 +07:00
parent b0cede75c1
commit 91a835cd64
3 changed files with 115 additions and 39 deletions

View File

@@ -32,10 +32,10 @@ using .learn
# using .interface # using .interface
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
""" """ version 0.0.2
Todo: Todo:
[*2] implement connection strength based on right or wrong answer
[2] implement connection strength based on right or wrong answer [*1] how to manage how much constrength increase and decrease
[4] implement dormant connection [4] implement dormant connection
[3] Δweight * connection strength [3] Δweight * connection strength
[] using RL to control learning signal [] using RL to control learning signal
@@ -64,6 +64,13 @@ using .learn
[DONE] add multi threads [DONE] add multi threads
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons [DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
[DONE] add maximum weight cap of each connection [DONE] add maximum weight cap of each connection
[DONE] weaker connection should be harder to increase strength. It requires a lot of
repeat activation to get it stronger. While strong connction requires a lot of
inactivation to get it weaker. The concept is strong connection will lock
correct neural pathway through repeated use of the right connection i.e. keep training
on the correct answer -> strengthen the right neural pathway (connections) ->
this correct neural pathway resist to change.
Not used connection should dissapear (forgetting).
Change from version: v06_36a Change from version: v06_36a
- -

View File

@@ -8,7 +8,7 @@ export learn!
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing}) function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing})
if correctAnswer === nothing if correctAnswer === nothing
correctAnswer_I = BitArray(zeros(length(modelRespond))) correctAnswer_I = BitArray(zeros(length(modelRespond)))
else else
@@ -21,7 +21,7 @@ end
""" knowledgeFn learn() """ knowledgeFn learn()
""" """
function learn!(kfn::kfn_1, correctAnswer::BitVector) function learn!(kfn::kfn_1, correctAnswer::BitVector)
# #TESTING compute kfn error for each neuron # compute kfn error for each neuron
# outs = [n.z_t1 for n in kfn.outputNeuronsArray] # outs = [n.z_t1 for n in kfn.outputNeuronsArray]
# for (i, out) in enumerate(outs) # for (i, out) in enumerate(outs)
# if out != correctAnswer[i] # need to adjust weight # if out != correctAnswer[i] # need to adjust weight
@@ -37,25 +37,25 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
# end # end
# end # end
#TESTING compute kfn error for each neuron # compute kfn error for each neuron
outs = [n.z_t1 for n in kfn.outputNeuronsArray] outs = [n.z_t1 for n in kfn.outputNeuronsArray]
for (i, out) in enumerate(outs) for (i, out) in enumerate(outs)
if out != correctAnswer[i] # need to adjust weight if out != correctAnswer[i] # need to adjust weight
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 / kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
kfn.outputNeuronsArray[i].v_th ) 100.0 / kfn.outputNeuronsArray[i].v_th )
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
# for n in kfn.neuronsArray # for n in kfn.neuronsArray
compute_wRecChange!(n, kfnError) compute_wRecChange!(n, kfnError)
learn!(n, kfn.firedNeurons, kfn.nExInType) learn!(n, kfn.firedNeurons, kfn.nExInType, correctAnswer[i])
end end
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError) compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType, learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort]) kfn.kfnParams[:totalInputPort], correctAnswer[i])
else # output neuron that is NOT associated with correctAnswer else # output neuron that is NOT associated with correctAnswer
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError) compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType, learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
kfn.kfnParams[:totalInputPort]) kfn.kfnParams[:totalInputPort], correctAnswer[i])
end end
end end
end end
@@ -94,12 +94,14 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
reset_epsilonRec!(n) reset_epsilonRec!(n)
end end
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron
# skip # skip
end end
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron
wSign_0 = sign.(n.wRec) # original sign wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange # merge wRecChange into wRec n.wRec += n.wRecChange # merge wRecChange into wRec
reset_wRecChange!(n) reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -110,12 +112,14 @@ function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
n.wRec .*= nonFlipedSign n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n) synapticConnStrength!(n, correctAnswer)
neuroplasticity!(n, firedNeurons, nExInType) neuroplasticity!(n, firedNeurons, nExInType)
end end
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron
wSign_0 = sign.(n.wRec) # original sign wSign_0 = sign.(n.wRec) # original sign
#TESTING strong connection gets less weight change, weak connection gets more weight change
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
n.wRec += n.wRecChange n.wRec += n.wRecChange
reset_wRecChange!(n) reset_wRecChange!(n)
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
@@ -126,7 +130,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNe
n.wRec .*= nonFlipedSign n.wRec .*= nonFlipedSign
capMaxWeight!(n.wRec) # cap maximum weight capMaxWeight!(n.wRec) # cap maximum weight
synapticConnStrength!(n) synapticConnStrength!(n, correctAnswer)
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort) neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
end end

View File

@@ -6,7 +6,7 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron,
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!, reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!, firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!, neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
gradient_withloss, capMaxWeight! gradient_withloss, capMaxWeight!, connStrengthAdjust
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
using GeneralUtils using GeneralUtils
@@ -257,28 +257,81 @@ function connStrengthAdjust(currentStrength::Float64)
return Δstrength::Float64 return Δstrength::Float64
end end
""" compute synaptic connection strength. bias will shift currentStrength to fit into """ Compute synaptic connection strength. bias will shift currentStrength to fit into
sigmoid operating range which centred at 0 and range is -37 to 37. sigmoid operating range which centred at 0 and range is -37 to 37.
# Example # Example
synaptic strength range is 0 to 10 synaptic strength range is 0 to 10
one may use bias = -5 to transform synaptic strength into range -5 to 5 one may use bias = -5 to transform synaptic strength into range -5 to 5
the return value is shifted back to original scale the return value is shifted back to original scale.
# Concept
weaker connection should be harder to increase strength. It requires a lot of
repeat activation to get it stronger. While strong connction requires a lot of
inactivation to get it weaker. The concept is strong connection will lock
correct neural pathway through repeated use of the right connection i.e. keep training
on the correct answer -> strengthen the right neural pathway (connections) ->
this correct neural pathway resist to change.
Not used connection should dissapear (forgetting).
""" """
function synapticConnStrength(currentStrength::Float64, updown::String) function synapticConnStrength(currentStrength::Float64, updown::String)
Δstrength = connStrengthAdjust(currentStrength) Δstrength = connStrengthAdjust(currentStrength)
if updown == "up" if updown == "up"
if currentStrength > 4 # strong connection
updatedStrength = currentStrength + Δstrength updatedStrength = currentStrength + Δstrength
else
updatedStrength = currentStrength + (Δstrength * 0.01)
end
elseif updown == "down"
if currentStrength > 4
updatedStrength = currentStrength - (Δstrength * 0.5)
else else
updatedStrength = currentStrength - Δstrength updatedStrength = currentStrength - Δstrength
end end
else
error("undefined condition line $(@__LINE__)")
end
return updatedStrength::Float64 return updatedStrength::Float64
end end
# function synapticConnStrength(currentStrength::Float64, updown::String)
# Δstrength = connStrengthAdjust(currentStrength)
# if updown == "up"
# updatedStrength = currentStrength + Δstrength
# else
# updatedStrength = currentStrength - Δstrength
# end
# return updatedStrength::Float64
# end
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes """ Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
below lowerlimit. below lowerlimit.
""" """
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}) # function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
# for (i, connStrength) in enumerate(n.synapticStrength)
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
# output neuron fires and trigger learn!() just before this synapticConnStrength
# calculation.
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
# """
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
# updatedConnStrength = synapticConnStrength(connStrength, updown)
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
# n.wRec[i] = 0.0
# end
# n.synapticStrength[i] = updatedConnStrength
# end
# end
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
if correctAnswer == true
for (i, connStrength) in enumerate(n.synapticStrength) for (i, connStrength) in enumerate(n.synapticStrength)
# check whether connStrength increase or decrease based on usage from n.epsilonRec # check whether connStrength increase or decrease based on usage from n.epsilonRec
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
@@ -288,7 +341,7 @@ function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
ok to use. n.z_i_t_commulative also span across a training sample without resetting. ok to use. n.z_i_t_commulative also span across a training sample without resetting.
""" """
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" # updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
updatedConnStrength = synapticConnStrength(connStrength, updown) updatedConnStrength = synapticConnStrength(connStrength, updown)
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength, updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit) n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
@@ -298,6 +351,18 @@ function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
end end
n.synapticStrength[i] = updatedConnStrength n.synapticStrength[i] = updatedConnStrength
end end
else
for (i, connStrength) in enumerate(n.synapticStrength)
updatedConnStrength = synapticConnStrength(connStrength, "down")
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
n.wRec[i] = 0.0
end
n.synapticStrength[i] = updatedConnStrength
end
end
end end
function synapticConnStrength!(n::inputNeuron) end function synapticConnStrength!(n::inputNeuron) end