experimenting compute neuron association
This commit is contained in:
@@ -32,10 +32,10 @@ using .learn
|
||||
# using .interface
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
"""
|
||||
""" version 0.0.2
|
||||
Todo:
|
||||
|
||||
[2] implement connection strength based on right or wrong answer
|
||||
[*2] implement connection strength based on right or wrong answer
|
||||
[*1] how to manage how much constrength increase and decrease
|
||||
[4] implement dormant connection
|
||||
[3] Δweight * connection strength
|
||||
[] using RL to control learning signal
|
||||
@@ -64,6 +64,13 @@ using .learn
|
||||
[DONE] add multi threads
|
||||
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
|
||||
[DONE] add maximum weight cap of each connection
|
||||
[DONE] weaker connection should be harder to increase strength. It requires a lot of
|
||||
repeat activation to get it stronger. While strong connction requires a lot of
|
||||
inactivation to get it weaker. The concept is strong connection will lock
|
||||
correct neural pathway through repeated use of the right connection i.e. keep training
|
||||
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||
this correct neural pathway resist to change.
|
||||
Not used connection should dissapear (forgetting).
|
||||
|
||||
Change from version: v06_36a
|
||||
-
|
||||
|
||||
30
src/learn.jl
30
src/learn.jl
@@ -8,7 +8,7 @@ export learn!
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
|
||||
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
|
||||
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing})
|
||||
if correctAnswer === nothing
|
||||
correctAnswer_I = BitArray(zeros(length(modelRespond)))
|
||||
else
|
||||
@@ -21,7 +21,7 @@ end
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
# #TESTING compute kfn error for each neuron
|
||||
# compute kfn error for each neuron
|
||||
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
# for (i, out) in enumerate(outs)
|
||||
# if out != correctAnswer[i] # need to adjust weight
|
||||
@@ -37,25 +37,25 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
# end
|
||||
# end
|
||||
|
||||
#TESTING compute kfn error for each neuron
|
||||
# compute kfn error for each neuron
|
||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, out) in enumerate(outs)
|
||||
if out != correctAnswer[i] # need to adjust weight
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 /
|
||||
kfn.outputNeuronsArray[i].v_th )
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
100.0 / kfn.outputNeuronsArray[i].v_th )
|
||||
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||
# for n in kfn.neuronsArray
|
||||
compute_wRecChange!(n, kfnError)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
||||
learn!(n, kfn.firedNeurons, kfn.nExInType, correctAnswer[i])
|
||||
end
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort])
|
||||
kfn.kfnParams[:totalInputPort], correctAnswer[i])
|
||||
else # output neuron that is NOT associated with correctAnswer
|
||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||
kfn.kfnParams[:totalInputPort])
|
||||
kfn.kfnParams[:totalInputPort], correctAnswer[i])
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -94,12 +94,14 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
|
||||
reset_epsilonRec!(n)
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
|
||||
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron
|
||||
# skip
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron
|
||||
wSign_0 = sign.(n.wRec) # original sign
|
||||
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
||||
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||
reset_wRecChange!(n)
|
||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||
@@ -110,12 +112,14 @@ function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
||||
n.wRec .*= nonFlipedSign
|
||||
capMaxWeight!(n.wRec) # cap maximum weight
|
||||
|
||||
synapticConnStrength!(n)
|
||||
synapticConnStrength!(n, correctAnswer)
|
||||
neuroplasticity!(n, firedNeurons, nExInType)
|
||||
end
|
||||
|
||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
|
||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron
|
||||
wSign_0 = sign.(n.wRec) # original sign
|
||||
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
||||
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||
n.wRec += n.wRecChange
|
||||
reset_wRecChange!(n)
|
||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||
@@ -126,7 +130,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNe
|
||||
n.wRec .*= nonFlipedSign
|
||||
capMaxWeight!(n.wRec) # cap maximum weight
|
||||
|
||||
synapticConnStrength!(n)
|
||||
synapticConnStrength!(n, correctAnswer)
|
||||
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||
end
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron,
|
||||
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
|
||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||
gradient_withloss, capMaxWeight!
|
||||
gradient_withloss, capMaxWeight!, connStrengthAdjust
|
||||
|
||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
||||
using GeneralUtils
|
||||
@@ -257,28 +257,81 @@ function connStrengthAdjust(currentStrength::Float64)
|
||||
return Δstrength::Float64
|
||||
end
|
||||
|
||||
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
||||
""" Compute synaptic connection strength. bias will shift currentStrength to fit into
|
||||
sigmoid operating range which centred at 0 and range is -37 to 37.
|
||||
|
||||
# Example
|
||||
synaptic strength range is 0 to 10
|
||||
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
||||
the return value is shifted back to original scale
|
||||
the return value is shifted back to original scale.
|
||||
|
||||
# Concept
|
||||
weaker connection should be harder to increase strength. It requires a lot of
|
||||
repeat activation to get it stronger. While strong connction requires a lot of
|
||||
inactivation to get it weaker. The concept is strong connection will lock
|
||||
correct neural pathway through repeated use of the right connection i.e. keep training
|
||||
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||
this correct neural pathway resist to change.
|
||||
Not used connection should dissapear (forgetting).
|
||||
"""
|
||||
function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||
Δstrength = connStrengthAdjust(currentStrength)
|
||||
|
||||
if updown == "up"
|
||||
if currentStrength > 4 # strong connection
|
||||
updatedStrength = currentStrength + Δstrength
|
||||
else
|
||||
updatedStrength = currentStrength + (Δstrength * 0.01)
|
||||
end
|
||||
elseif updown == "down"
|
||||
if currentStrength > 4
|
||||
updatedStrength = currentStrength - (Δstrength * 0.5)
|
||||
else
|
||||
updatedStrength = currentStrength - Δstrength
|
||||
end
|
||||
else
|
||||
error("undefined condition line $(@__LINE__)")
|
||||
end
|
||||
return updatedStrength::Float64
|
||||
end
|
||||
# function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||
# Δstrength = connStrengthAdjust(currentStrength)
|
||||
|
||||
# if updown == "up"
|
||||
# updatedStrength = currentStrength + Δstrength
|
||||
# else
|
||||
# updatedStrength = currentStrength - Δstrength
|
||||
# end
|
||||
# return updatedStrength::Float64
|
||||
# end
|
||||
|
||||
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
||||
below lowerlimit.
|
||||
"""
|
||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
# for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
||||
# output neuron fires and trigger learn!() just before this synapticConnStrength
|
||||
# calculation.
|
||||
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||
# """
|
||||
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
||||
# updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||
# n.wRec[i] = 0.0
|
||||
# end
|
||||
# n.synapticStrength[i] = updatedConnStrength
|
||||
# end
|
||||
# end
|
||||
|
||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
|
||||
if correctAnswer == true
|
||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||
@@ -288,7 +341,7 @@ function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||
"""
|
||||
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
||||
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
|
||||
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||
@@ -298,6 +351,18 @@ function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
end
|
||||
n.synapticStrength[i] = updatedConnStrength
|
||||
end
|
||||
else
|
||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
updatedConnStrength = synapticConnStrength(connStrength, "down")
|
||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||
n.wRec[i] = 0.0
|
||||
end
|
||||
n.synapticStrength[i] = updatedConnStrength
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function synapticConnStrength!(n::inputNeuron) end
|
||||
|
||||
Reference in New Issue
Block a user