experimenting compute neuron association
This commit is contained in:
@@ -32,10 +32,10 @@ using .learn
|
|||||||
# using .interface
|
# using .interface
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
"""
|
""" version 0.0.2
|
||||||
Todo:
|
Todo:
|
||||||
|
[*2] implement connection strength based on right or wrong answer
|
||||||
[2] implement connection strength based on right or wrong answer
|
[*1] how to manage how much constrength increase and decrease
|
||||||
[4] implement dormant connection
|
[4] implement dormant connection
|
||||||
[3] Δweight * connection strength
|
[3] Δweight * connection strength
|
||||||
[] using RL to control learning signal
|
[] using RL to control learning signal
|
||||||
@@ -64,6 +64,13 @@ using .learn
|
|||||||
[DONE] add multi threads
|
[DONE] add multi threads
|
||||||
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
|
[DONE] during 0 training if 1-9 output neuron fires, adjust weight only those neurons
|
||||||
[DONE] add maximum weight cap of each connection
|
[DONE] add maximum weight cap of each connection
|
||||||
|
[DONE] weaker connection should be harder to increase strength. It requires a lot of
|
||||||
|
repeat activation to get it stronger. While strong connction requires a lot of
|
||||||
|
inactivation to get it weaker. The concept is strong connection will lock
|
||||||
|
correct neural pathway through repeated use of the right connection i.e. keep training
|
||||||
|
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||||
|
this correct neural pathway resist to change.
|
||||||
|
Not used connection should dissapear (forgetting).
|
||||||
|
|
||||||
Change from version: v06_36a
|
Change from version: v06_36a
|
||||||
-
|
-
|
||||||
|
|||||||
30
src/learn.jl
30
src/learn.jl
@@ -8,7 +8,7 @@ export learn!
|
|||||||
|
|
||||||
#------------------------------------------------------------------------------------------------100
|
#------------------------------------------------------------------------------------------------100
|
||||||
|
|
||||||
function learn!(m::model, modelRespond::Vector{Bool}, correctAnswer::Union{AbstractVector, Nothing})
|
function learn!(m::model, modelRespond::AbstractVector, correctAnswer::Union{AbstractVector, Nothing})
|
||||||
if correctAnswer === nothing
|
if correctAnswer === nothing
|
||||||
correctAnswer_I = BitArray(zeros(length(modelRespond)))
|
correctAnswer_I = BitArray(zeros(length(modelRespond)))
|
||||||
else
|
else
|
||||||
@@ -21,7 +21,7 @@ end
|
|||||||
""" knowledgeFn learn()
|
""" knowledgeFn learn()
|
||||||
"""
|
"""
|
||||||
function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||||
# #TESTING compute kfn error for each neuron
|
# compute kfn error for each neuron
|
||||||
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||||
# for (i, out) in enumerate(outs)
|
# for (i, out) in enumerate(outs)
|
||||||
# if out != correctAnswer[i] # need to adjust weight
|
# if out != correctAnswer[i] # need to adjust weight
|
||||||
@@ -37,25 +37,25 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
|||||||
# end
|
# end
|
||||||
# end
|
# end
|
||||||
|
|
||||||
#TESTING compute kfn error for each neuron
|
# compute kfn error for each neuron
|
||||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||||
for (i, out) in enumerate(outs)
|
for (i, out) in enumerate(outs)
|
||||||
if out != correctAnswer[i] # need to adjust weight
|
if out != correctAnswer[i] # need to adjust weight
|
||||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) * 100.0 /
|
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||||
kfn.outputNeuronsArray[i].v_th )
|
100.0 / kfn.outputNeuronsArray[i].v_th )
|
||||||
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
||||||
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
Threads.@threads for n in kfn.neuronsArray # multithread is not atomic and causing error
|
||||||
# for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
compute_wRecChange!(n, kfnError)
|
compute_wRecChange!(n, kfnError)
|
||||||
learn!(n, kfn.firedNeurons, kfn.nExInType)
|
learn!(n, kfn.firedNeurons, kfn.nExInType, correctAnswer[i])
|
||||||
end
|
end
|
||||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||||
kfn.kfnParams[:totalInputPort])
|
kfn.kfnParams[:totalInputPort], correctAnswer[i])
|
||||||
else # output neuron that is NOT associated with correctAnswer
|
else # output neuron that is NOT associated with correctAnswer
|
||||||
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
compute_wRecChange!(kfn.outputNeuronsArray[i], kfnError)
|
||||||
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
learn!(kfn.outputNeuronsArray[i], kfn.firedNeurons, kfn.nExInType,
|
||||||
kfn.kfnParams[:totalInputPort])
|
kfn.kfnParams[:totalInputPort], correctAnswer[i])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -94,12 +94,14 @@ function compute_wRecChange!(n::linearNeuron, error::Float64)
|
|||||||
reset_epsilonRec!(n)
|
reset_epsilonRec!(n)
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType) where T<:inputNeuron
|
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:inputNeuron
|
||||||
# skip
|
# skip
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
function learn!(n::T, firedNeurons, nExInType, correctAnswer) where T<:computeNeuron
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
|
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
||||||
|
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
@@ -110,12 +112,14 @@ function learn!(n::T, firedNeurons, nExInType) where T<:computeNeuron
|
|||||||
n.wRec .*= nonFlipedSign
|
n.wRec .*= nonFlipedSign
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
|
||||||
synapticConnStrength!(n)
|
synapticConnStrength!(n, correctAnswer)
|
||||||
neuroplasticity!(n, firedNeurons, nExInType)
|
neuroplasticity!(n, firedNeurons, nExInType)
|
||||||
end
|
end
|
||||||
|
|
||||||
function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNeuron
|
function learn!(n::T, firedNeurons, nExInType, totalInputPort, correctAnswer) where T<:outputNeuron
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
|
#TESTING strong connection gets less weight change, weak connection gets more weight change
|
||||||
|
n.wRecChange .*= (connStrengthAdjust.(n.synapticStrength))
|
||||||
n.wRec += n.wRecChange
|
n.wRec += n.wRecChange
|
||||||
reset_wRecChange!(n)
|
reset_wRecChange!(n)
|
||||||
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
wSign_1 = sign.(n.wRec) # check for fliped sign, 1 indicates non-fliped sign
|
||||||
@@ -126,7 +130,7 @@ function learn!(n::T, firedNeurons, nExInType, totalInputPort) where T<:outputNe
|
|||||||
n.wRec .*= nonFlipedSign
|
n.wRec .*= nonFlipedSign
|
||||||
capMaxWeight!(n.wRec) # cap maximum weight
|
capMaxWeight!(n.wRec) # cap maximum weight
|
||||||
|
|
||||||
synapticConnStrength!(n)
|
synapticConnStrength!(n, correctAnswer)
|
||||||
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
neuroplasticity!(n,firedNeurons, nExInType, totalInputPort)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
111
src/snn_utils.jl
111
src/snn_utils.jl
@@ -6,7 +6,7 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron,
|
|||||||
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
|
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!, reset_wRecChange!,
|
||||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||||
gradient_withloss, capMaxWeight!
|
gradient_withloss, capMaxWeight!, connStrengthAdjust
|
||||||
|
|
||||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -257,46 +257,111 @@ function connStrengthAdjust(currentStrength::Float64)
|
|||||||
return Δstrength::Float64
|
return Δstrength::Float64
|
||||||
end
|
end
|
||||||
|
|
||||||
""" compute synaptic connection strength. bias will shift currentStrength to fit into
|
""" Compute synaptic connection strength. bias will shift currentStrength to fit into
|
||||||
sigmoid operating range which centred at 0 and range is -37 to 37.
|
sigmoid operating range which centred at 0 and range is -37 to 37.
|
||||||
|
|
||||||
# Example
|
# Example
|
||||||
synaptic strength range is 0 to 10
|
synaptic strength range is 0 to 10
|
||||||
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
one may use bias = -5 to transform synaptic strength into range -5 to 5
|
||||||
the return value is shifted back to original scale
|
the return value is shifted back to original scale.
|
||||||
|
|
||||||
|
# Concept
|
||||||
|
weaker connection should be harder to increase strength. It requires a lot of
|
||||||
|
repeat activation to get it stronger. While strong connction requires a lot of
|
||||||
|
inactivation to get it weaker. The concept is strong connection will lock
|
||||||
|
correct neural pathway through repeated use of the right connection i.e. keep training
|
||||||
|
on the correct answer -> strengthen the right neural pathway (connections) ->
|
||||||
|
this correct neural pathway resist to change.
|
||||||
|
Not used connection should dissapear (forgetting).
|
||||||
"""
|
"""
|
||||||
function synapticConnStrength(currentStrength::Float64, updown::String)
|
function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||||
Δstrength = connStrengthAdjust(currentStrength)
|
Δstrength = connStrengthAdjust(currentStrength)
|
||||||
|
|
||||||
if updown == "up"
|
if updown == "up"
|
||||||
updatedStrength = currentStrength + Δstrength
|
if currentStrength > 4 # strong connection
|
||||||
|
updatedStrength = currentStrength + Δstrength
|
||||||
|
else
|
||||||
|
updatedStrength = currentStrength + (Δstrength * 0.01)
|
||||||
|
end
|
||||||
|
elseif updown == "down"
|
||||||
|
if currentStrength > 4
|
||||||
|
updatedStrength = currentStrength - (Δstrength * 0.5)
|
||||||
|
else
|
||||||
|
updatedStrength = currentStrength - Δstrength
|
||||||
|
end
|
||||||
else
|
else
|
||||||
updatedStrength = currentStrength - Δstrength
|
error("undefined condition line $(@__LINE__)")
|
||||||
end
|
end
|
||||||
return updatedStrength::Float64
|
return updatedStrength::Float64
|
||||||
end
|
end
|
||||||
|
# function synapticConnStrength(currentStrength::Float64, updown::String)
|
||||||
|
# Δstrength = connStrengthAdjust(currentStrength)
|
||||||
|
|
||||||
|
# if updown == "up"
|
||||||
|
# updatedStrength = currentStrength + Δstrength
|
||||||
|
# else
|
||||||
|
# updatedStrength = currentStrength - Δstrength
|
||||||
|
# end
|
||||||
|
# return updatedStrength::Float64
|
||||||
|
# end
|
||||||
|
|
||||||
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
""" Compute all synaptic connection strength of a neuron. Also mark n.wRec to 0 if wRec goes
|
||||||
below lowerlimit.
|
below lowerlimit.
|
||||||
"""
|
"""
|
||||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
# function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
# for (i, connStrength) in enumerate(n.synapticStrength)
|
||||||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
# # check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||||
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
# """ use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||||
calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
# calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
||||||
output neuron fires and trigger learn!() just before this synapticConnStrength
|
# output neuron fires and trigger learn!() just before this synapticConnStrength
|
||||||
calculation.
|
# calculation.
|
||||||
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
# Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||||
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
# ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||||
"""
|
# """
|
||||||
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
# updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
||||||
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
# updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||||
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
# updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||||
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
# n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||||
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
# # at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||||
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
# if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||||
n.wRec[i] = 0.0
|
# n.wRec[i] = 0.0
|
||||||
|
# end
|
||||||
|
# n.synapticStrength[i] = updatedConnStrength
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron}, correctAnswer::Bool)
|
||||||
|
if correctAnswer == true
|
||||||
|
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||||
|
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||||
|
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||||
|
calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
||||||
|
output neuron fires and trigger learn!() just before this synapticConnStrength
|
||||||
|
calculation.
|
||||||
|
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||||
|
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||||
|
"""
|
||||||
|
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up"
|
||||||
|
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||||
|
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||||
|
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||||
|
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||||
|
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||||
|
n.wRec[i] = 0.0
|
||||||
|
end
|
||||||
|
n.synapticStrength[i] = updatedConnStrength
|
||||||
|
end
|
||||||
|
else
|
||||||
|
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||||
|
updatedConnStrength = synapticConnStrength(connStrength, "down")
|
||||||
|
updatedConnStrength = GeneralUtils.limitvalue(updatedConnStrength,
|
||||||
|
n.synapticStrengthLimit.lowerlimit, n.synapticStrengthLimit.upperlimit)
|
||||||
|
# at lowerlimit, mark wRec at this position to 0. for new random synaptic conn
|
||||||
|
if updatedConnStrength == n.synapticStrengthLimit.lowerlimit[1]
|
||||||
|
n.wRec[i] = 0.0
|
||||||
|
end
|
||||||
|
n.synapticStrength[i] = updatedConnStrength
|
||||||
end
|
end
|
||||||
n.synapticStrength[i] = updatedConnStrength
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user