experimenting compute neuron association
This commit is contained in:
@@ -34,6 +34,7 @@ using .interface
|
||||
|
||||
"""
|
||||
Todo:
|
||||
[*1] add maximum weight cap of each connection
|
||||
[2] implement connection strength based on right or wrong answer
|
||||
[4] implement dormant connection
|
||||
[3] Δweight * connection strength
|
||||
|
||||
@@ -14,11 +14,11 @@ function (m::model)(input_data::AbstractVector)
|
||||
m.timeStep += 1
|
||||
|
||||
# process all corresponding KFN
|
||||
raw_model_respond, outputNeuron_v_t1 = m.knowledgeFn[:I](m, input_data)
|
||||
# raw_model_respond, outputNeuron_v_t1, firedNeurons_t1 = m.knowledgeFn[:I](m, input_data)
|
||||
|
||||
# the 2nd return (KFN error) should not be used as model error but I use it because there is
|
||||
# only one KFN in a model right now
|
||||
return raw_model_respond::Array{Bool}, outputNeuron_v_t1::Array{Float64}
|
||||
return m.knowledgeFn[:I](m, input_data)
|
||||
end
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
@@ -96,7 +96,8 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
||||
out = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
outputNeuron_v_t1 = [n.v_t1 for n in kfn.outputNeuronsArray]
|
||||
|
||||
return out::Array{Bool}, outputNeuron_v_t1::Array{Float64}
|
||||
return out::Array{Bool}, outputNeuron_v_t1::Array{Float64}, sum(kfn.firedNeurons_t1),
|
||||
kfn.exSignalSum, kfn.inSignalsum
|
||||
end
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
@@ -220,7 +221,18 @@ function (n::linearNeuron)(kfn::T) where T<:knowledgeFn
|
||||
n.v_t1 = n.alpha * n.v_t
|
||||
n.vError = n.v_t1 # store voltage that will be used to calculate error later
|
||||
else
|
||||
n.recSignal = sum(n.wRec .* n.z_i_t) # signal from other neuron that this neuron subscribed
|
||||
recSignal = n.wRec .* n.z_i_t
|
||||
if n.id == 1 #FIXME debugging output neuron dead
|
||||
for i in recSignal
|
||||
if i > 0
|
||||
kfn.exSignalSum += i
|
||||
elseif i < 0
|
||||
kfn.inSignalsum += i
|
||||
else
|
||||
end
|
||||
end
|
||||
end
|
||||
n.recSignal = sum(recSignal) # signal from other neuron that this neuron subscribed
|
||||
n.alpha_v_t = n.alpha * n.v_t
|
||||
n.v_t1 = n.alpha_v_t + n.recSignal
|
||||
n.v_t1 = no_negative!(n.v_t1)
|
||||
|
||||
52
src/learn.jl
52
src/learn.jl
@@ -23,41 +23,41 @@ end
|
||||
""" knowledgeFn learn()
|
||||
"""
|
||||
function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
# compute kfn error for each neuron
|
||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, out) in enumerate(outs)
|
||||
if out != correctAnswer[i] # need to adjust weight
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
100 / kfn.outputNeuronsArray[i].v_th )
|
||||
|
||||
Threads.@threads for n in kfn.neuronsArray
|
||||
# for n in kfn.neuronsArray
|
||||
learn!(n, kfnError)
|
||||
end
|
||||
|
||||
learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
end
|
||||
end
|
||||
|
||||
# #TESTING compute kfn error for each neuron
|
||||
# outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
# for (i, out) in enumerate(outs)
|
||||
# if out != correctAnswer[i] # need to adjust weight
|
||||
# kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||
# 100 / kfn.outputNeuronsArray[i].v_th )
|
||||
# if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
# # for n in kfn.neuronsArray
|
||||
# learn!(n, kfnError)
|
||||
# end
|
||||
|
||||
# learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
# else # output neuron that is NOT associated with correctAnswer
|
||||
# learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
# Threads.@threads for n in kfn.neuronsArray
|
||||
# # for n in kfn.neuronsArray
|
||||
# learn!(n, kfnError)
|
||||
# end
|
||||
|
||||
# learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
# end
|
||||
# end
|
||||
|
||||
#TESTING compute kfn error for each neuron
|
||||
outs = [n.z_t1 for n in kfn.outputNeuronsArray]
|
||||
for (i, out) in enumerate(outs)
|
||||
if out != correctAnswer[i] # need to adjust weight
|
||||
kfnError = ( (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) /
|
||||
kfn.outputNeuronsArray[i].v_th )
|
||||
if correctAnswer[i] == 1 # output neuron that associated with correctAnswer
|
||||
Threads.@threads for n in kfn.neuronsArray
|
||||
# for n in kfn.neuronsArray
|
||||
learn!(n, kfnError)
|
||||
end
|
||||
|
||||
learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
else # output neuron that is NOT associated with correctAnswer
|
||||
learn!(kfn.outputNeuronsArray[i], kfnError)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# wrap up learning session
|
||||
if kfn.learningStage == "end_learning"
|
||||
Threads.@threads for n in kfn.neuronsArray
|
||||
@@ -69,8 +69,9 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||
# normalize wRec peak to prevent input signal overwhelming neuron
|
||||
normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||
# set weight that fliped sign to 0 for random new connection
|
||||
# set weight that fliped sign to 0 for random new connection
|
||||
n.wRec .*= nonFlipedSign
|
||||
capMaxWeight!(n.wRec) # cap maximum weight
|
||||
|
||||
synapticConnStrength!(n)
|
||||
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType)
|
||||
@@ -84,6 +85,7 @@ function learn!(kfn::kfn_1, correctAnswer::BitVector)
|
||||
nonFlipedSign = isequal.(wSign_0, wSign_1) # 1 not fliped, 0 fliped
|
||||
normalizePeak!(n.wRec, n.wRecChange, 2)
|
||||
n.wRec .*= nonFlipedSign # set weight that fliped sign to 0 for random new connection
|
||||
capMaxWeight!(n.wRec) # cap maximum weight
|
||||
|
||||
synapticConnStrength!(n)
|
||||
neuroplasticity!(n, kfn.firedNeurons, kfn.nExInType, kfn.kfnParams[:totalInputPort])
|
||||
|
||||
@@ -7,7 +7,7 @@ export calculate_α, calculate_ρ, calculate_k, timestep_forward!, init_neuron,
|
||||
reset_epsilonRecA!, synapticConnStrength!, normalizePeak!,
|
||||
firing_rate_error!, firing_rate_regulator!, update_Bn!, cal_firing_reg!,
|
||||
neuroplasticity!, shakeup!, reset_learning_no_wchange!, adjust_internal_learning_rate!,
|
||||
gradient_withloss
|
||||
gradient_withloss, capMaxWeight!
|
||||
|
||||
using Statistics, Random, LinearAlgebra, Distributions, Zygote, Flux
|
||||
using GeneralUtils
|
||||
@@ -283,12 +283,12 @@ end
|
||||
function synapticConnStrength!(n::Union{computeNeuron, outputNeuron})
|
||||
for (i, connStrength) in enumerate(n.synapticStrength)
|
||||
# check whether connStrength increase or decrease based on usage from n.epsilonRec
|
||||
""" use n.wRecChange instead of the best choise, epsilonRec, here because ΔwRecChange
|
||||
""" use n.z_i_t_commulative instead of the best choice, epsilonRec, here because ΔwRecChange
|
||||
calculation in learn!() will reset epsilonRec to zeroes vector in case where
|
||||
output neuron fires and trigger learn!() just before this synapticConnStrength
|
||||
calculation.
|
||||
Since n.wRecChange indicates whether a synaptic connection were used or not, it is
|
||||
ok to use. n.wRecChange also span across a training sample without resetting.
|
||||
Since n.z_i_t_commulative indicates whether a synaptic connection were used or not, it is
|
||||
ok to use. n.z_i_t_commulative also span across a training sample without resetting.
|
||||
"""
|
||||
updown = n.z_i_t_commulative[i] == 0 ? "down" : "up" #
|
||||
updatedConnStrength = synapticConnStrength(connStrength, updown)
|
||||
@@ -441,7 +441,12 @@ function neuroplasticity!(n::outputNeuron, firedNeurons::Vector,
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Cap maximum weight of each neuron connection
|
||||
"""
|
||||
function capMaxWeight!(v::Vector{Float64}, max=1.0)
|
||||
originalSign = sign.(v)
|
||||
v = originalSign .* GeneralUtils.replaceMoreThan.(abs.(v), max)
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -115,6 +115,9 @@ Base.@kwdef mutable struct kfn_1 <: knowledgeFn
|
||||
nInhabitory::Array{Int64} = Int64[] # list of inhabitory neuron id
|
||||
nExInType::Array{Int64} = Int64[] # list all neuron EX or IN
|
||||
excitatoryPercent::Int64 = 60 # percentage of excitatory neuron, inhabitory percent will be 100-ExcitatoryPercent
|
||||
|
||||
exSignalSum = 0
|
||||
inSignalsum = 0
|
||||
end
|
||||
|
||||
#------------------------------------------------------------------------------------------------100
|
||||
@@ -347,7 +350,7 @@ Base.@kwdef mutable struct lifNeuron <: computeNeuron
|
||||
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||
refractoryCounter::Int64 = 0
|
||||
tau_m::Float64 = 0.0 # τ_m, membrane time constant in millisecond
|
||||
eta::Float64 = 0.01 # η, learning rate
|
||||
eta::Float64 = 0.0001 # η, learning rate
|
||||
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||
@@ -435,7 +438,7 @@ Base.@kwdef mutable struct alifNeuron <: computeNeuron
|
||||
eRec_v::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from v_t
|
||||
eRec_a::Array{Float64} = Float64[] # a component of neuron's eligibility trace resulted from av_th
|
||||
eRec::Array{Float64} = Float64[] # neuron's eligibility trace
|
||||
eta::Float64 = 0.01 # eta, learning rate
|
||||
eta::Float64 = 0.0001 # eta, learning rate
|
||||
gammaPd::Float64 = 0.3 # γ_pd, discount factor, value from paper
|
||||
phi::Float64 = 0.0 # ϕ, psuedo derivative
|
||||
refractoryDuration::Int64 = 3 # neuron's refractory period in millisecond
|
||||
@@ -545,7 +548,7 @@ Base.@kwdef mutable struct linearNeuron <: outputNeuron
|
||||
refractoryDuration::Int64 = 3 # neuron's refratory period in millisecond
|
||||
refractoryCounter::Int64 = 0
|
||||
tau_out::Float64 = 0.0 # τ_out, membrane time constant in millisecond
|
||||
eta::Float64 = 0.01 # η, learning rate
|
||||
eta::Float64 = 0.0001 # η, learning rate
|
||||
wRecChange::Array{Float64} = Float64[] # Δw_rec, cumulated wRec change
|
||||
recSignal::Float64 = 0.0 # incoming recurrent signal
|
||||
alpha_v_t::Float64 = 0.0 # alpha * v_t
|
||||
|
||||
Reference in New Issue
Block a user