building forward()

This commit is contained in:
ton
2023-07-10 21:02:12 +07:00
parent d427875679
commit 3482e87892
4 changed files with 194 additions and 17 deletions

View File

@@ -2,7 +2,7 @@
julia_version = "1.9.2" julia_version = "1.9.2"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "bf29b9b6c22b1a96a3e10687abec46c3de1b7715" project_hash = "1a1cddac46fdd2108611b4e2f350497572f0c8d4"
[[deps.AbstractFFTs]] [[deps.AbstractFFTs]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
@@ -111,6 +111,12 @@ git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "8.8.1+0" version = "8.8.1+0"
[[deps.Calculus]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
version = "0.5.1"
[[deps.ChainRules]] [[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
git-tree-sha1 = "1cdf290d4feec68824bfb84f4bfc9f3aba185647" git-tree-sha1 = "1cdf290d4feec68824bfb84f4bfc9f3aba185647"
@@ -222,6 +228,20 @@ version = "1.15.1"
deps = ["Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "e76a3281de2719d7c81ed62c6ea7057380c87b1d"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.98"
[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
DistributionsDensityInterfaceExt = "DensityInterface"
[deps.Distributions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
[[deps.DocStringExtensions]] [[deps.DocStringExtensions]]
deps = ["LibGit2"] deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
@@ -233,6 +253,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0" version = "1.6.0"
[[deps.DualNumbers]]
deps = ["Calculus", "NaNMath", "SpecialFunctions"]
git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.8"
[[deps.ExprTools]] [[deps.ExprTools]]
git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -311,6 +337,18 @@ git-tree-sha1 = "cb090aea21c6ca78d59672a7e7d13bd56d09de64"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.20.3" version = "0.20.3"
[[deps.GeneralUtils]]
deps = ["CUDA", "DataStructures", "Distributions", "Flux", "JSON3"]
path = "C:\\Users\\naraw\\.julia\\dev\\GeneralUtils"
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
version = "0.1.0"
[[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "ce7ea9cc5db29563b1fe20196b6d23ab3b111384"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.18"
[[deps.IRTools]] [[deps.IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"] deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5"
@@ -342,6 +380,12 @@ git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.4.1" version = "1.4.1"
[[deps.JSON3]]
deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
git-tree-sha1 = "5b62d93f2582b09e469b3099d839c2d2ebf5066d"
uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.13.1"
[[deps.JuliaVariables]] [[deps.JuliaVariables]]
deps = ["MLStyle", "NameResolution"] deps = ["MLStyle", "NameResolution"]
git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70"
@@ -527,6 +571,18 @@ git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.0" version = "1.6.0"
[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.17"
[[deps.Parsers]]
deps = ["Dates", "PrecompileTools", "UUIDs"]
git-tree-sha1 = "4b2e829ee66d4218e0cef22c0a64ee37cf258c29"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.7.1"
[[deps.Pkg]] [[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -559,6 +615,12 @@ git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539"
uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
version = "0.1.4" version = "0.1.4"
[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.8.2"
[[deps.REPL]] [[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
@@ -596,6 +658,18 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df" uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0" version = "1.3.0"
[[deps.Rmath]]
deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.7.1"
[[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.4.0+0"
[[deps.SHA]] [[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0" version = "0.7.0"
@@ -683,12 +757,36 @@ git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.0" version = "0.34.0"
[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "1.3.0"
[deps.StatsFuns.extensions]
StatsFunsChainRulesCoreExt = "ChainRulesCore"
StatsFunsInverseFunctionsExt = "InverseFunctions"
[deps.StatsFuns.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.StructArrays]] [[deps.StructArrays]]
deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"]
git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389"
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.6.15" version = "0.6.15"
[[deps.StructTypes]]
deps = ["Dates", "UUIDs"]
git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
version = "1.10.0"
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]] [[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"

View File

@@ -6,4 +6,5 @@ version = "0.1.0"
[deps] [deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

View File

@@ -23,7 +23,7 @@ using .interface
Todo: Todo:
[*1] knowledgeFn in GPU format [*1] knowledgeFn in GPU format
[] use partial error update for computeNeuron [] use partial error update for computeNeuron
[] use integrate_neuron_params synapticConnectionPercent = 20% [] use integrate_neuron_params synapticConnectionPercent LESS THAN 100%
[2] implement dormant connection and pruning machanism. the longer the training the longer [2] implement dormant connection and pruning machanism. the longer the training the longer
0 weight stay 0. 0 weight stay 0.
[] using RL to control learning signal [] using RL to control learning signal

View File

@@ -2,41 +2,71 @@ module type
export export
# struct # struct
kfn kfn_1
# function # function
using Random using Random, GeneralUtils
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
rng = MersenneTwister(1234)
abstract type Ironpen end abstract type Ironpen end
abstract type knowledgeFn <: Ironpen end abstract type knowledgeFn <: Ironpen end
rng = MersenneTwister(1234)
#------------------------------------------------------------------------------------------------100 #------------------------------------------------------------------------------------------------100
Base.@kwdef mutable struct kfn_1 <: knowledgeFn
Base.@kwdef mutable struct kfn <: knowledgeFn
params::Dict = Dict() # store params of knowledgeFn itself for later use params::Dict = Dict() # store params of knowledgeFn itself for later use
timeStep::AbstractArray = [0] timeStep::AbstractArray = [0]
refractory::Union{AbstractArray, Nothing} = nothing refractory::Union{AbstractArray, Nothing} = nothing
learningStage::AbstractArray = [0] # 0 inference, 1 start, 2 during, 3 end learning
z_i_t1::Union{AbstractArray, Nothing} = nothing # 2D activation matrix z_i_t1::Union{AbstractArray, Nothing} = nothing # 2D activation matrix
z_i_t::Union{AbstractArray, Nothing} = nothing z_i_t0::Union{AbstractArray, Nothing} = nothing
z_t::Union{AbstractArray, Nothing} = nothing
z_t1::Union{AbstractArray, Nothing} = nothing lif_w::Union{AbstractArray, Nothing} = nothing
alif_w::Union{AbstractArray, Nothing} = nothing
end end
function kfn(kfnParams::Dict) # outer constructor
kfn_1 = kfn() function kfn_1(params::Dict)
kfn_1.params = kfnParams kfn = kfn_1()
kfn.params = params
# initialize activation matrix
row, col = kfn.params[:inputPort][:noise][:numbers]
row += kfn.params[:inputPort][:signal][:numbers][1]
col += kfn.params[:inputPort][:signal][:numbers][2]
col += kfn.params[:computeNeuron][:lif][:numbers][2]
col += kfn.params[:computeNeuron][:alif][:numbers][2]
if kfn_1.params[:computeNeuronNumber] < kfn_1.params[:totalInputPort] kfn.z_i_t1 = zeros(row, col, 1)
throw(error("number of compute neuron must be greater than input neuron")) kfn.z_i_t0 = zeros(row, col, 1)
kfn.lif_w = zeros(row, col, row*col)
kfn.alif_w = zeros(row, col, row*col)
# lif subscription
row, col, z = size(kfn.lif_w) # row*col is synaptic subscribe weight for each neuron in z
synapticConnectionPercent = kfn.params[:computeNeuron][:lif][:params][:synapticConnectionPercent]
synapticConnection = Int(floor(z*synapticConnectionPercent/100))
for slice in eachslice(kfn.lif_w, dims=3)
pool = shuffle!([1:z...])[1:synapticConnection]
for i in pool
slice[i] = randn()/10
end
end end
error("debug end") # alif subscription
row, col, z = size(kfn.alif_w) # row*col is synaptic subscribe weight for each neuron in z
synapticConnectionPercent = kfn.params[:computeNeuron][:alif][:params][:synapticConnectionPercent]
synapticConnection = Int(floor(z*synapticConnectionPercent/100))
for slice in eachslice(kfn.alif_w, dims=3)
pool = shuffle!([1:z...])[1:synapticConnection]
for i in pool
slice[i] = randn()/10
end
end end
@@ -49,6 +79,54 @@ end
# error("debug end outer constructor")
return kfn
end
# kfn forward
function (kfn::kfn_1)(input::AbstractArray)
kfn.timeStep .+= 1
# row, col = size(input) # if input is a 2D matrix
println(">>> 5 ", size(input))
println(">>> 6 ", size(kfn.z_i_t1))
#WORKING multiply input with kfn.z_i_t1 may be using cartesian coordinates
println(">>> 7 ", view(kfn.z_i_t1, :, 1, :))
view(kfn.z_i_t1, :, 1) .= input
println(">>> 8 ", kfn.z_i_t1[:, 1])
# multiply kfn.z_i_t1 with kfn.lif_w
r = GeneralUtils.batchMatEleMul(kfn.z_i_t1, kfn.lif_w)
println(size(r))
error("debug end kfn forward")
end