add multithreading
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
julia_version = "1.9.0"
|
julia_version = "1.9.0"
|
||||||
manifest_format = "2.0"
|
manifest_format = "2.0"
|
||||||
project_hash = "6da2bd801ebd94457c5a5cb36ae71250437066e8"
|
project_hash = "b9e7ae4b78dc59a5adb629a04e856c4fedc6fb60"
|
||||||
|
|
||||||
[[deps.AbstractFFTs]]
|
[[deps.AbstractFFTs]]
|
||||||
deps = ["LinearAlgebra"]
|
deps = ["LinearAlgebra"]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ version = "0.1.0"
|
|||||||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
|
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
|
||||||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
|
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
|
||||||
|
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
|
||||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ module forward
|
|||||||
|
|
||||||
using Flux.Optimise: apply!
|
using Flux.Optimise: apply!
|
||||||
|
|
||||||
using Statistics, Flux, Random, LinearAlgebra
|
using Statistics, Flux, Random, LinearAlgebra, JSON3
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types, ..snn_utils
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
@@ -77,8 +77,8 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
|
|
||||||
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
|
kfn.firedNeurons_t0 = [n.z_t for n in kfn.neuronsArray] #TODO check if it is used?
|
||||||
|
|
||||||
#CHANGE Threads.@threads for n in kfn.neuronsArray
|
Threads.@threads for n in kfn.neuronsArray
|
||||||
for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
n(kfn)
|
n(kfn)
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -88,8 +88,8 @@ function (kfn::kfn_1)(m::model, input_data::AbstractVector)
|
|||||||
kfn.firedNeurons |> unique! # use for random new neuron connection
|
kfn.firedNeurons |> unique! # use for random new neuron connection
|
||||||
end
|
end
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.outputNeuronsArray
|
Threads.@threads for n in kfn.outputNeuronsArray
|
||||||
for n in kfn.outputNeuronsArray
|
# for n in kfn.outputNeuronsArray
|
||||||
n(kfn)
|
n(kfn)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
10
src/learn.jl
10
src/learn.jl
@@ -2,7 +2,7 @@ module learn
|
|||||||
|
|
||||||
using Flux.Optimise: apply!
|
using Flux.Optimise: apply!
|
||||||
|
|
||||||
using Statistics, Flux, Random, LinearAlgebra
|
using Statistics, Flux, Random, LinearAlgebra, JSON3
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..types, ..snn_utils
|
using ..types, ..snn_utils
|
||||||
|
|
||||||
@@ -30,8 +30,8 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
|||||||
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
kfnError = (kfn.outputNeuronsArray[i].v_th - kfn.outputNeuronsArray[i].vError) *
|
||||||
100 / kfn.outputNeuronsArray[i].v_th
|
100 / kfn.outputNeuronsArray[i].v_th
|
||||||
|
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
Threads.@threads for n in kfn.neuronsArray
|
||||||
for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
learn!(n, kfnError)
|
learn!(n, kfnError)
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -41,8 +41,8 @@ function learn!(kfn::kfn_1, correctAnswer::AbstractVector)
|
|||||||
|
|
||||||
# wrap up learning session
|
# wrap up learning session
|
||||||
if kfn.learningStage == "end_learning"
|
if kfn.learningStage == "end_learning"
|
||||||
# Threads.@threads for n in kfn.neuronsArray
|
Threads.@threads for n in kfn.neuronsArray
|
||||||
for n in kfn.neuronsArray
|
# for n in kfn.neuronsArray
|
||||||
if typeof(n) <: computeNeuron
|
if typeof(n) <: computeNeuron
|
||||||
wSign_0 = sign.(n.wRec) # original sign
|
wSign_0 = sign.(n.wRec) # original sign
|
||||||
n.wRec += n.wRecChange # merge wRecChange into wRec
|
n.wRec += n.wRecChange # merge wRecChange into wRec
|
||||||
|
|||||||
Reference in New Issue
Block a user