This commit is contained in:
narawat lamaiin
2024-12-09 20:28:02 +07:00
parent 3338085567
commit cae94e5690
15 changed files with 2942 additions and 2942 deletions

32
.vscode/launch.json vendored
View File

@@ -1,17 +1,17 @@
{ {
// Use IntelliSense to learn about possible attributes. // Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes. // Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"type": "julia", "type": "julia",
"request": "launch", "request": "launch",
"name": "Run active Julia file", "name": "Run active Julia file",
"program": "${file}", "program": "${file}",
"stopOnEntry": false, "stopOnEntry": false,
"cwd": "${workspaceFolder}", "cwd": "${workspaceFolder}",
"juliaEnv": "${command:activeJuliaEnvironment}" "juliaEnv": "${command:activeJuliaEnvironment}"
} }
] ]
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,8 @@
name = "LLMMCTS" name = "LLMMCTS"
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241" uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
authors = ["narawat lamaiin <narawat@outlook.com>"] authors = ["narawat lamaiin <narawat@outlook.com>"]
version = "0.1.0" version = "0.1.0"
[deps] [deps]
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

View File

@@ -1,476 +1,476 @@
# This file is machine-generated - editing it directly is not advised # This file is machine-generated - editing it directly is not advised
julia_version = "1.10.3" julia_version = "1.10.3"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "b7e1f171d36dc4812d6c1445da530f513320e6cd" project_hash = "b7e1f171d36dc4812d6c1445da530f513320e6cd"
[[deps.AliasTables]] [[deps.AliasTables]]
deps = ["PtrArrays", "Random"] deps = ["PtrArrays", "Random"]
git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff"
uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
version = "1.1.3" version = "1.1.3"
[[deps.ArgTools]] [[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1" version = "1.1.1"
[[deps.Artifacts]] [[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[deps.Base64]] [[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[deps.Calculus]] [[deps.Calculus]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
version = "0.5.1" version = "0.5.1"
[[deps.CodeTracking]] [[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"] deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.3.5" version = "1.3.5"
[[deps.Compat]] [[deps.Compat]]
deps = ["TOML", "UUIDs"] deps = ["TOML", "UUIDs"]
git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.15.0" version = "4.15.0"
weakdeps = ["Dates", "LinearAlgebra"] weakdeps = ["Dates", "LinearAlgebra"]
[deps.Compat.extensions] [deps.Compat.extensions]
CompatLinearAlgebraExt = "LinearAlgebra" CompatLinearAlgebraExt = "LinearAlgebra"
[[deps.CompilerSupportLibraries_jll]] [[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0" version = "1.1.1+0"
[[deps.DataAPI]] [[deps.DataAPI]]
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0" version = "1.16.0"
[[deps.DataStructures]] [[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"] deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.20" version = "0.18.20"
[[deps.Dates]] [[deps.Dates]]
deps = ["Printf"] deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[deps.Distributed]] [[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"] deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]] [[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.109" version = "0.25.109"
[deps.Distributions.extensions] [deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore" DistributionsChainRulesCoreExt = "ChainRulesCore"
DistributionsDensityInterfaceExt = "DensityInterface" DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test" DistributionsTestExt = "Test"
[deps.Distributions.weakdeps] [deps.Distributions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.DocStringExtensions]] [[deps.DocStringExtensions]]
deps = ["LibGit2"] deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.9.3" version = "0.9.3"
[[deps.Downloads]] [[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0" version = "1.6.0"
[[deps.DualNumbers]] [[deps.DualNumbers]]
deps = ["Calculus", "NaNMath", "SpecialFunctions"] deps = ["Calculus", "NaNMath", "SpecialFunctions"]
git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.8" version = "0.6.8"
[[deps.FileWatching]] [[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
[[deps.FillArrays]] [[deps.FillArrays]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.11.0" version = "1.11.0"
weakdeps = ["PDMats", "SparseArrays", "Statistics"] weakdeps = ["PDMats", "SparseArrays", "Statistics"]
[deps.FillArrays.extensions] [deps.FillArrays.extensions]
FillArraysPDMatsExt = "PDMats" FillArraysPDMatsExt = "PDMats"
FillArraysSparseArraysExt = "SparseArrays" FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics" FillArraysStatisticsExt = "Statistics"
[[deps.GeneralUtils]] [[deps.GeneralUtils]]
deps = ["DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"] deps = ["DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"]
path = "/appfolder/app/privatejuliapkg/GeneralUtils" path = "/appfolder/app/privatejuliapkg/GeneralUtils"
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
version = "0.1.0" version = "0.1.0"
[[deps.HypergeometricFunctions]] [[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.23" version = "0.3.23"
[[deps.InteractiveUtils]] [[deps.InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[deps.IrrationalConstants]] [[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2" version = "0.2.2"
[[deps.JLLWrappers]] [[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"] deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0" version = "1.5.0"
[[deps.JSON3]] [[deps.JSON3]]
deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b"
uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.14.0" version = "1.14.0"
[deps.JSON3.extensions] [deps.JSON3.extensions]
JSON3ArrowExt = ["ArrowTypes"] JSON3ArrowExt = ["ArrowTypes"]
[deps.JSON3.weakdeps] [deps.JSON3.weakdeps]
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
[[deps.JuliaInterpreter]] [[deps.JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b" git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.31" version = "0.9.31"
[[deps.LibCURL]] [[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"] deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.4" version = "0.6.4"
[[deps.LibCURL_jll]] [[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0" version = "8.4.0+0"
[[deps.LibGit2]] [[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[deps.LibGit2_jll]] [[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0" version = "1.6.4+0"
[[deps.LibSSH2_jll]] [[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"] deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.11.0+1" version = "1.11.0+1"
[[deps.Libdl]] [[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[deps.LinearAlgebra]] [[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.LogExpFunctions]] [[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.27" version = "0.3.27"
[deps.LogExpFunctions.extensions] [deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
LogExpFunctionsInverseFunctionsExt = "InverseFunctions" LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
[deps.LogExpFunctions.weakdeps] [deps.LogExpFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.Logging]] [[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[deps.LoweredCodeUtils]] [[deps.LoweredCodeUtils]]
deps = ["JuliaInterpreter"] deps = ["JuliaInterpreter"]
git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426" git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "2.4.6" version = "2.4.6"
[[deps.MQTTClient]] [[deps.MQTTClient]]
deps = ["Distributed", "Random", "Sockets"] deps = ["Distributed", "Random", "Sockets"]
git-tree-sha1 = "f2597b290d4bf17b577346153cd2ddf9accb5c26" git-tree-sha1 = "f2597b290d4bf17b577346153cd2ddf9accb5c26"
uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
version = "0.3.1" version = "0.3.1"
weakdeps = ["PrecompileTools"] weakdeps = ["PrecompileTools"]
[deps.MQTTClient.extensions] [deps.MQTTClient.extensions]
PrecompileMQTT = "PrecompileTools" PrecompileMQTT = "PrecompileTools"
[[deps.Markdown]] [[deps.Markdown]]
deps = ["Base64"] deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]] [[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+1" version = "2.28.2+1"
[[deps.Missings]] [[deps.Missings]]
deps = ["DataAPI"] deps = ["DataAPI"]
git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.2.0" version = "1.2.0"
[[deps.Mmap]] [[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[deps.MozillaCACerts_jll]] [[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159" uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2023.1.10" version = "2023.1.10"
[[deps.NaNMath]] [[deps.NaNMath]]
deps = ["OpenLibm_jll"] deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.2" version = "1.0.2"
[[deps.NetworkOptions]] [[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0" version = "1.2.0"
[[deps.OpenBLAS_jll]] [[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4" version = "0.3.23+4"
[[deps.OpenLibm_jll]] [[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112" uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2" version = "0.8.1+2"
[[deps.OpenSpecFun_jll]] [[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0" version = "0.5.5+0"
[[deps.OrderedCollections]] [[deps.OrderedCollections]]
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.3" version = "1.6.3"
[[deps.PDMats]] [[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.31" version = "0.11.31"
[[deps.Parsers]] [[deps.Parsers]]
deps = ["Dates", "PrecompileTools", "UUIDs"] deps = ["Dates", "PrecompileTools", "UUIDs"]
git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.8.1" version = "2.8.1"
[[deps.Pkg]] [[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0" version = "1.10.0"
[[deps.PrecompileTools]] [[deps.PrecompileTools]]
deps = ["Preferences"] deps = ["Preferences"]
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
version = "1.2.1" version = "1.2.1"
[[deps.Preferences]] [[deps.Preferences]]
deps = ["TOML"] deps = ["TOML"]
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
uuid = "21216c6a-2e73-6563-6e65-726566657250" uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.3" version = "1.4.3"
[[deps.Printf]] [[deps.Printf]]
deps = ["Unicode"] deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[deps.PtrArrays]] [[deps.PtrArrays]]
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.2.0" version = "1.2.0"
[[deps.QuadGK]] [[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"] deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.9.4" version = "2.9.4"
[[deps.REPL]] [[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[deps.Random]] [[deps.Random]]
deps = ["SHA"] deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[deps.Reexport]] [[deps.Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69" uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.2.2" version = "1.2.2"
[[deps.Requires]] [[deps.Requires]]
deps = ["UUIDs"] deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df" uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0" version = "1.3.0"
[[deps.Revise]] [[deps.Revise]]
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"] deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
git-tree-sha1 = "12aa2d7593df490c407a3bbd8b86b8b515017f3e" git-tree-sha1 = "12aa2d7593df490c407a3bbd8b86b8b515017f3e"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.5.14" version = "3.5.14"
[[deps.Rmath]] [[deps.Rmath]]
deps = ["Random", "Rmath_jll"] deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.7.1" version = "0.7.1"
[[deps.Rmath_jll]] [[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"] deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.4.2+0" version = "0.4.2+0"
[[deps.SHA]] [[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0" version = "0.7.0"
[[deps.Serialization]] [[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[deps.Sockets]] [[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc" uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[deps.SortingAlgorithms]] [[deps.SortingAlgorithms]]
deps = ["DataStructures"] deps = ["DataStructures"]
git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.2.1" version = "1.2.1"
[[deps.SparseArrays]] [[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0" version = "1.10.0"
[[deps.SpecialFunctions]] [[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b" uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.4.0" version = "2.4.0"
[deps.SpecialFunctions.extensions] [deps.SpecialFunctions.extensions]
SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
[deps.SpecialFunctions.weakdeps] [deps.SpecialFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
[[deps.Statistics]] [[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"] deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
version = "1.10.0" version = "1.10.0"
[[deps.StatsAPI]] [[deps.StatsAPI]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.7.0" version = "1.7.0"
[[deps.StatsBase]] [[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.3" version = "0.34.3"
[[deps.StatsFuns]] [[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "1.3.1" version = "1.3.1"
[deps.StatsFuns.extensions] [deps.StatsFuns.extensions]
StatsFunsChainRulesCoreExt = "ChainRulesCore" StatsFunsChainRulesCoreExt = "ChainRulesCore"
StatsFunsInverseFunctionsExt = "InverseFunctions" StatsFunsInverseFunctionsExt = "InverseFunctions"
[deps.StatsFuns.weakdeps] [deps.StatsFuns.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.StructTypes]] [[deps.StructTypes]]
deps = ["Dates", "UUIDs"] deps = ["Dates", "UUIDs"]
git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
version = "1.10.0" version = "1.10.0"
[[deps.SuiteSparse]] [[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]] [[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1" version = "7.2.1+1"
[[deps.TOML]] [[deps.TOML]]
deps = ["Dates"] deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3" version = "1.0.3"
[[deps.Tar]] [[deps.Tar]]
deps = ["ArgTools", "SHA"] deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0" version = "1.10.0"
[[deps.UUIDs]] [[deps.UUIDs]]
deps = ["Random", "SHA"] deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[deps.Unicode]] [[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[deps.Zlib_jll]] [[deps.Zlib_jll]]
deps = ["Libdl"] deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a" uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1" version = "1.2.13+1"
[[deps.libblastrampoline_jll]] [[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1" version = "5.8.0+1"
[[deps.nghttp2_jll]] [[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.52.0+1" version = "1.52.0+1"
[[deps.p7zip_jll]] [[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+2" version = "17.4.0+2"

View File

@@ -1,8 +1,8 @@
name = "LLMMCTS" name = "LLMMCTS"
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241" uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
authors = ["narawat lamaiin <narawat@outlook.com>"] authors = ["narawat lamaiin <narawat@outlook.com>"]
version = "0.1.0" version = "0.1.0"
[deps] [deps]
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

View File

@@ -1,28 +1,28 @@
module LLMMCTS module LLMMCTS
# export agent # export agent
""" Order by dependencies of each file. The 1st included file must not depend on any other """ Order by dependencies of each file. The 1st included file must not depend on any other
files, and each file can only depend on files included before it. files, and each file can only depend on files included before it.
""" """
include("type.jl") include("type.jl")
using .type using .type
include("util.jl") include("util.jl")
using .util using .util
include("mcts.jl") include("mcts.jl")
using .mcts using .mcts
include("interface.jl") include("interface.jl")
using .interface using .interface
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
end # module LLMMCTS end # module LLMMCTS

View File

@@ -1,180 +1,180 @@
module interface module interface
export runMCTS export runMCTS
using ..type, ..mcts using ..type, ..mcts
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
# Arguments # Arguments
- `a::agent` - `a::agent`
one of Yiem's agents one of Yiem's agents
- `initial state` - `initial state`
initial state initial state
- `decisionMaker::Function` - `decisionMaker::Function`
decide what action to take decide what action to take
- `evaluator::Function` - `evaluator::Function`
assess the value of the state assess the value of the state
- `reflector::Function` - `reflector::Function`
generate lesson from trajectory and reward generate lesson from trajectory and reward
- `isterminal::Function` - `isterminal::Function`
determine whether a given state is a terminal state determine whether a given state is a terminal state
- `n::Integer` - `n::Integer`
how many times action will be sampled from decisionMaker how many times action will be sampled from decisionMaker
- `w::Float64` - `w::Float64`
exploration weight. Value is usually between 1 to 2. exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50% Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
Value 2.0 makes MCTS aggressively search the tree Value 2.0 makes MCTS aggressively search the tree
# Return # Return
- `plan::Vector{Dict}` - `plan::Vector{Dict}`
best plan best plan
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
[] update docstring [] update docstring
[] return best action [] return best action
# Signature # Signature
""" """
function runMCTS( function runMCTS(
config::T1, config::T1,
initialState, initialState,
decisionMaker::Function, decisionMaker::Function,
evaluator::Function, evaluator::Function,
reflector::Function, reflector::Function,
transition::Function, transition::Function,
; ;
totalsample::Integer=3, totalsample::Integer=3,
maxDepth::Integer=3, maxDepth::Integer=3,
maxiterations::Integer=10, maxiterations::Integer=10,
explorationweight::Number=1.0, explorationweight::Number=1.0,
) where {T1<:AbstractDict} ) where {T1<:AbstractDict}
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
node.visits += 1 node.visits += 1
while !isleaf(node) while !isleaf(node)
node = UCTselect(node, explorationweight) node = UCTselect(node, explorationweight)
end end
if node.isterminal if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state, # MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation # do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward) backpropagate(leafNode, node.reward)
else else
expand(config, node, decisionMaker, evaluator, reflector, transition; expand(config, node, decisionMaker, evaluator, reflector, transition;
totalsample=totalsample) totalsample=totalsample)
leafNode = selectChildNode(node) leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(config, leafNode, decisionMaker, evaluator, simTrajectoryReward, terminalstate = simulate(config, leafNode, decisionMaker, evaluator,
reflector, transition; maxDepth=maxDepth, totalsample=totalsample) reflector, transition; maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing #XXX not sure why I need this if terminalstate !== nothing #XXX not sure why I need this
terminalstate[:totalTrajectoryReward] = simTrajectoryReward terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io # open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate) # JSON3.pretty(io, terminalstate)
# end # end
backpropagate(leafNode, simTrajectoryReward) backpropagate(leafNode, simTrajectoryReward)
end end
end end
bestNextState = selectBestNextState(root) bestNextState = selectBestNextState(root)
besttrajectory = selectBestTrajectory(root) besttrajectory = selectBestTrajectory(root)
return (bestNextState.state, besttrajectory.state) return (bestNextState.state, besttrajectory.state)
end end
end # module interface end # module interface

View File

@@ -1,438 +1,438 @@
module mcts module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode, export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState expand, simulate, makeNewState
using GeneralUtils using GeneralUtils
using ..type using ..type
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" """
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
- [] update docs - [] update docs
- [x] implement the function - [x] implement the function
# Signature # Signature
""" """
function selectBestNextState(node::MCTSNode)::MCTSNode function selectBestNextState(node::MCTSNode)::MCTSNode
highestProgressValue = 0 highestProgressValue = 0
nodekey = nothing nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node # if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children]) stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0 if stateValueSum != 0
for (k, childnode) in node.children for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childnode.nodekey nodekey = childnode.nodekey
end end
end end
else else
for (k, childnode) in node.children for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childnode.nodekey nodekey = childnode.nodekey
end end
end end
end end
return node.children[nodekey] return node.children[nodekey]
end end
""" """
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
- [] update docs - [] update docs
- [x] implement the function - [x] implement the function
# Signature # Signature
""" """
function selectBestTrajectory(node::MCTSNode)::MCTSNode function selectBestTrajectory(node::MCTSNode)::MCTSNode
while !isleaf(node) while !isleaf(node)
node = selectBestNextState(node) node = selectBestNextState(node)
end end
return node return node
end end
""" Backpropagate reward along the simulation chain """ Backpropagate reward along the simulation chain
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
leaf node of a search tree leaf node of a search tree
- `simTrajectoryReward::T` - `simTrajectoryReward::T`
total reward from trajectory simulation total reward from trajectory simulation
# Return # Return
- `No return` - `No return`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# Signature # Signature
""" """
function backpropagate(node::MCTSNode, simTrajectoryReward::T; function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node) while !isroot(node)
# Update the statistics of the current node based on the result of the playout # Update the statistics of the current node based on the result of the playout
node.visits += 1 node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent node = node.parent
end end
end end
""" Determine whether a node is a leaf node of a search tree. """ Determine whether a node is a leaf node of a search tree.
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
a search tree node a search tree node
# Return # Return
- `result::Bool` - `result::Bool`
true if it is a leaf node, false otherwise. true if it is a leaf node, false otherwise.
# Example # Example
```jldoctest ```jldoctest
julia> using Revise julia> using Revise
julia> using YiemAgent, DataStructures julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}( julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(), :customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(), :storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}( :thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?", :question=> "How are you?",
) )
) )
julia> statetype = typeof(initialState) julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}()) julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root) julia> YiemAgent.isleaf(root)
true true
``` ```
# TODO # TODO
[] update docs [] update docs
# Signature # Signature
""" """
isleaf(node::MCTSNode)::Bool = isempty(node.children) isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node """ Determine wheter a given node is a root node
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `isrootnode::Bool` - `isrootnode::Bool`
true if the given node is root node, false otherwise true if the given node is root node, false otherwise
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# Signature # Signature
""" """
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue """ Select child node based on the highest statevalue
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# Signature # Signature
""" """
function selectChildNode(node::MCTSNode)::MCTSNode function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = 0 highestProgressValue = 0
nodekey = nothing nodekey = nothing
# loop thought node children dictionary to find the highest progress value # loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward potential = childNode.progressvalue + childNode.reward
if childNode.reward > 0 #XXX for testing. remove when done. if childNode.reward > 0 #XXX for testing. remove when done.
println("") println("")
end end
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childNode.nodekey nodekey = childNode.nodekey
end end
end end
return node.children[nodekey] return node.children[nodekey]
end end
""" Expand selected node """ Expand selected node
# Arguments # Arguments
- `a::T1` - `a::T1`
One of YiemAgent's agent One of YiemAgent's agent
- `node::MCTSNode` - `node::MCTSNode`
MCTS node MCTS node
- `state::T2` - `state::T2`
a state of a game. Can be a Dict or something else. a state of a game. Can be a Dict or something else.
- `decisionMaker::Function` - `decisionMaker::Function`
a function that output Thought and Action a function that output Thought and Action
- `evaluator::Function` - `evaluator::Function`
a function that output trajectory progress score a function that output trajectory progress score
# Return # Return
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
[] update docstring [] update docstring
[] try loop should limit to 3 times. if not succeed, skip [] try loop should limit to 3 times. if not succeed, skip
[] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise. [] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise.
[x] store feedback -> state -> agent. [x] store feedback -> state -> agent.
# Signature # Signature
""" """
function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function, function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; totalsample::Integer=3 reflector::Function, transition::Function; totalsample::Integer=3
) where {T1<:AbstractDict} ) where {T1<:AbstractDict}
nthSample = 0 nthSample = 0
while true while true
nthSample += 1 nthSample += 1
if nthSample <= totalsample if nthSample <= totalsample
newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker, newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker,
evaluator, reflector) evaluator, reflector)
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}()) newstate[:isterminal], node, Dict{String, MCTSNode}())
end end
else else
break break
end end
end end
end end
""" Simulate interactions between agent and environment """ Simulate interactions between agent and environment
# Arguments # Arguments
- `a::T` - `a::T`
one of YiemAgent's agent one of YiemAgent's agent
- `node::MCTSNode` - `node::MCTSNode`
node that will be a simulation starting point. node that will be a simulation starting point.
- `decisionMaker::Function` - `decisionMaker::Function`
function that receive state return Thought and Action function that receive state return Thought and Action
# Return # Return
- `simTrajectoryReward::Number` - `simTrajectoryReward::Number`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
- [] update docs - [] update docs
# Signature # Signature
""" """
function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3 reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict} )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
for depth in 1:maxDepth for depth in 1:maxDepth
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
if node.isterminal if node.isterminal
terminalstate = node.state terminalstate = node.state
break break
else else
expand(config, node, decisionMaker, evaluator, reflector, transition; expand(config, node, decisionMaker, evaluator, reflector, transition;
totalsample=totalsample) totalsample=totalsample)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end
return (simTrajectoryReward, terminalstate) return (simTrajectoryReward, terminalstate)
end end
""" """
# Arguments # Arguments
# Return # Return
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
- [] update docstring - [] update docstring
- [x] implement the function - [x] implement the function
# Signature # Signature
""" """
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
currentstate_latestThoughtKey, currentstate_latestThoughtIndice = currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
currentstate_nextIndice = currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice") latestActionKey = Symbol("action_$currentstate_nextIndice")
_, thoughtDict_latestThoughtIndice = _, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought") GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1 if thoughtDict_latestThoughtIndice == -1
(:thought, :action) (:thought, :action)
else else
( (
Symbol("thought_$thoughtDict_latestThoughtIndice"), Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"), Symbol("action_$thoughtDict_latestThoughtIndice"),
) )
end end
# add Thought, action, observation to thoughtHistory # add Thought, action, observation to thoughtHistory
newstate = deepcopy(currentstate) newstate = deepcopy(currentstate)
newstate[:thoughtHistory][currentstate_latestThoughtKey] = newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey] thoughtDict[thoughtDict_latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
newObservationKey = Symbol("observation_$(currentstate_nextIndice)") newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward newstate[:reward] = reward
newstate[:select] = select newstate[:select] = select
newstate[:isterminal] = isterminal newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate) return (newNodeKey, newstate)
end end
end # module mcts end # module mcts

View File

@@ -1,116 +1,116 @@
module type module type
export MCTSNode export MCTSNode
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree """ a node for MCTS search tree
# Arguments # Arguments
- `state::T` - `state::T`
a state of a game. Can be a Dict or something else. a state of a game. Can be a Dict or something else.
- `visits::Integer ` - `visits::Integer `
number of time the game visits this state number of time the game visits this state
- `stateValue::Float64` - `stateValue::Float64`
state value state value
- `children::Dict{T, MCTSNode}` - `children::Dict{T, MCTSNode}`
children node children node
# Return # Return
- `nothing` - `nothing`
# Example # Example
```jldoctest ```jldoctest
julia> state = Dict( julia> state = Dict(
:info=> Dict(), # keyword info :info=> Dict(), # keyword info
:thoughtHistory=> Dict( :thoughtHistory=> Dict(
:question=> _, :question=> _,
:thought_1=> _, :thought_1=> _,
:action_1=> _, :action_1=> _,
:observation_1=> _, :observation_1=> _,
:thought_2=> _, :thought_2=> _,
... ...
) )
) )
``` ```
# TODO # TODO
[] update docstring [] update docstring
# Signature # Signature
""" """
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2 nodekey::T2
state::T1 state::T1
visits::Integer visits::Integer
progressvalue::Number # estimate value by LLM's reasoning progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # store discounted commulative reward (gather from its child node) statevalue::Number # store discounted commulative reward (gather from its child node)
reward::Number # this node's own reward reward::Number # this node's own reward
isterminal::Bool isterminal::Bool
parent::Union{MCTSNode, Nothing} parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode} children::Dict{String, MCTSNode}
end end
end # module type end # module type

View File

@@ -1,139 +1,139 @@
module util module util
export UCTselect export UCTselect
using ..type using ..type
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Select a node based on UCT score """ Select a node based on UCT score
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
mcts node mcts node
- `w::T` - `w::T`
exploration weight. Value is usually between 1 to 2. exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%. Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
Value 2.0 makes MCTS aggressively search the tree. Value 2.0 makes MCTS aggressively search the tree.
# Return # Return
- `selectedNode::MCTSNode` - `selectedNode::MCTSNode`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# Signature # Signature
""" """
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
maxUCT = -Inf maxUCT = -Inf
selectedNode = nothing selectedNode = nothing
for (childState, childNode) in node.children for (childState, childNode) in node.children
UCTvalue = UCTvalue =
if childNode.visits != 0 if childNode.visits != 0
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
childNode.statevalue + weightedterm childNode.statevalue + weightedterm
else # node.visits == 0 makes sqrt() in explore term error else # node.visits == 0 makes sqrt() in explore term error
childNode.progressvalue # exploit term childNode.progressvalue # exploit term
end end
if UCTvalue > maxUCT if UCTvalue > maxUCT
maxUCT = UCTvalue maxUCT = UCTvalue
selectedNode = childNode selectedNode = childNode
end end
end end
return selectedNode return selectedNode
end end
end # module util end # module util

View File

@@ -1,38 +1,38 @@
module LLMMCTS module LLMMCTS
# export agent # export agent
""" Order by dependencies of each file. The 1st included file must not depend on any other """ Order by dependencies of each file. The 1st included file must not depend on any other
files and each file can only depend on the file included before it. files and each file can only depend on the file included before it.
""" """
include("type.jl") include("type.jl")
using .type using .type
include("util.jl") include("util.jl")
using .util using .util
include("mcts.jl") include("mcts.jl")
using .mcts using .mcts
include("interface.jl") include("interface.jl")
using .interface using .interface
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" version 0.0.2 """ version 0.0.2
Todo: Todo:
- [] - []
Change from version: 0.0.1 Change from version: 0.0.1
- -
All features All features
""" """
end # module LLMMCTS end # module LLMMCTS

View File

@@ -1,228 +1,228 @@
module interface module interface
export runMCTS export runMCTS
using ..type, ..mcts, ..util using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
# Arguments # Arguments
- `initialstate::T` - `initialstate::T`
initial state initial state
- `transition::Function` - `transition::Function`
a function that define how the state transitions a function that define how the state transitions
- `transitionargs::NamedTuple` - `transitionargs::NamedTuple`
arguments for transition function arguments for transition function
# Keyword Arguments # Keyword Arguments
- `totalsample::Integer` - `totalsample::Integer`
a number of child state MCTS sample at each node during expansion phase a number of child state MCTS sample at each node during expansion phase
- `maxdepth::Integer` - `maxdepth::Integer`
a number of levels MCTS goes during simulation phase a number of levels MCTS goes during simulation phase
- `maxiterations::Integer` - `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
- `explorationweight::Number` - `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state. aggressively explore new state.
# Return # Return
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}` - `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
the best next state and the best final state the best next state and the best final state
# Example # Example
Refers to SQLLLM package Refers to SQLLLM package
# Signature # Signature
""" """
function runMCTS( function runMCTS(
initialstate::T, initialstate::T,
transition::Function, transition::Function,
transitionargs::NamedTuple, transitionargs::NamedTuple,
; ;
totalsample::Integer=3, totalsample::Integer=3,
maxdepth::Integer=3, maxdepth::Integer=3,
maxiterations::Integer=10, maxiterations::Integer=10,
explorationweight::Number=1.0, explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
node.visits += 1 node.visits += 1
while !isleaf(node) while !isleaf(node)
node = UCTselect(node, explorationweight) node = UCTselect(node, explorationweight)
end end
if node.isterminal if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state, # MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration # do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
else else
expand(node, transition, transitionargs; expand(node, transition, transitionargs;
totalsample=totalsample) totalsample=totalsample)
leafNode = selectChildNode(node) leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample) maxdepth=maxdepth, totalsample=totalsample)
# if terminalstate !== nothing #XXX not sure why I need this # if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end # end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io # open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate) # JSON3.pretty(io, terminalstate)
# end # end
backpropagate(leafNode, simTrajectoryReward) backpropagate(leafNode, simTrajectoryReward)
end end
# stop if the early stop condition is met # stop if the early stop condition is met
if typeof(earlystop) <: Function && earlystop(node.state) if typeof(earlystop) <: Function && earlystop(node.state)
break break
end end
end end
bestNextState = selectBestNextNode(root) bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root) besttrajectory = selectBestTrajectoryNode(root)
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end end
# function runMCTS( # function runMCTS(
# initialstate::T, # initialstate::T,
# transition::Function, # transition::Function,
# transitionargs::NamedTuple, # transitionargs::NamedTuple,
# ; # ;
# totalsample::Integer=3, # totalsample::Integer=3,
# maxdepth::Integer=3, # maxdepth::Integer=3,
# maxiterations::Integer=10, # maxiterations::Integer=10,
# explorationweight::Number=1.0, # explorationweight::Number=1.0,
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} # )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) # root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
# for nth in 1:maxiterations # for nth in 1:maxiterations
# node = root # node = root
# node.visits += 1 # node.visits += 1
# while !isleaf(node) # while !isleaf(node)
# node = UCTselect(node, explorationweight) # node = UCTselect(node, explorationweight)
# end # end
# if node.isterminal # if node.isterminal
# # MCTS arrive at the leaf node that is also a terminal state, # # MCTS arrive at the leaf node that is also a terminal state,
# # do nothing then go directly to backpropagation. It means the end of this iteration # # do nothing then go directly to backpropagation. It means the end of this iteration
# backpropagate(leafNode, node.reward) # backpropagate(leafNode, node.reward)
# else # else
# expand(node, transition, transitionargs; # expand(node, transition, transitionargs;
# totalsample=totalsample) # totalsample=totalsample)
# leafNode = selectChildNode(node) # leafNode = selectChildNode(node)
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; # simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
# maxdepth=maxdepth, totalsample=totalsample) # maxdepth=maxdepth, totalsample=totalsample)
# # if terminalstate !== nothing #XXX not sure why I need this # # if terminalstate !== nothing #XXX not sure why I need this
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward # # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# # end # # end
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation # #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# # open("trajectory.json", "w") do io # # open("trajectory.json", "w") do io
# # JSON3.pretty(io, terminalstate) # # JSON3.pretty(io, terminalstate)
# # end # # end
# backpropagate(leafNode, simTrajectoryReward) # backpropagate(leafNode, simTrajectoryReward)
# end # end
# end # end
# bestNextState = selectBestNextNode(root) # bestNextState = selectBestNextNode(root)
# besttrajectory = selectBestTrajectoryNode(root) # besttrajectory = selectBestTrajectoryNode(root)
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) # return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
# end # end
end # module interface end # module interface

View File

@@ -1,429 +1,429 @@
module mcts module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode, export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState expand, simulate, makeNewState
using Base.Threads using Base.Threads
using GeneralUtils using GeneralUtils
using ..type using ..type
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" """
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Signature # Signature
""" """
function selectBestNextNode(node::MCTSNode)::MCTSNode function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1 highestProgressValue = -1
nodekey = nothing nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node # if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children]) stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0 if stateValueSum != 0
for (k, childnode) in node.children for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childnode.nodekey nodekey = childnode.nodekey
end end
end end
else else
for (k, childnode) in node.children for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childnode.nodekey nodekey = childnode.nodekey
end end
end end
end end
return node.children[nodekey] return node.children[nodekey]
end end
""" """
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Signature # Signature
""" """
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
while !isleaf(node) while !isleaf(node)
node = selectBestNextNode(node) node = selectBestNextNode(node)
end end
return node return node
end end
""" Backpropagate reward along the simulation chain """ Backpropagate reward along the simulation chain
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
leaf node of a search tree leaf node of a search tree
- `simTrajectoryReward::T` - `simTrajectoryReward::T`
total reward from trajectory simulation total reward from trajectory simulation
- `discountRewardCoeff::AbstractFloat` - `discountRewardCoeff::AbstractFloat`
A discount reward coefficient to reduce future reward. The futher in the future the lower A discount reward coefficient to reduce future reward. The futher in the future the lower
reward it is now. reward it is now.
# Return # Return
- `None` - `None`
# Signature # Signature
""" """
function backpropagate(node::MCTSNode, simTrajectoryReward::T; function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node) while !isroot(node)
# Update the statistics of the current node based on the result of the playout # Update the statistics of the current node based on the result of the playout
node.visits += 1 node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent node = node.parent
end end
end end
""" Determine whether a node is a leaf node of a search tree. """ Determine whether a node is a leaf node of a search tree.
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
a search tree node a search tree node
# Return # Return
- `result::Bool` - `result::Bool`
true if it is a leaf node, false otherwise. true if it is a leaf node, false otherwise.
# Example # Example
```jldoctest ```jldoctest
julia> using Revise julia> using Revise
julia> using YiemAgent, DataStructures julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}( julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(), :customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(), :storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}( :thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?", :question=> "How are you?",
) )
) )
julia> statetype = typeof(initialState) julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}()) julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root) julia> YiemAgent.isleaf(root)
true true
``` ```
# TODO # TODO
[] update docs [] update docs
# Signature # Signature
""" """
isleaf(node::MCTSNode)::Bool = isempty(node.children) isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node """ Determine wheter a given node is a root node
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `isrootnode::Bool` - `isrootnode::Bool`
true if the given node is root node, false otherwise true if the given node is root node, false otherwise
# Signature # Signature
""" """
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue """ Select child node based on the highest statevalue
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Signature # Signature
""" """
function selectChildNode(node::MCTSNode)::MCTSNode function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1 highestProgressValue = -1
nodekey = nothing nodekey = nothing
# loop thought node children dictionary to find the highest progress value # loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward potential = childNode.progressvalue + childNode.reward
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childNode.nodekey nodekey = childNode.nodekey
end end
end end
return node.children[nodekey] return node.children[nodekey]
end end
""" Expand selected node. """ Expand selected node.
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
MCTS node MCTS node
- `transition::Function` - `transition::Function`
A function that handles state transition. A function that handles state transition.
- `transitionargs::NamedTuple` - `transitionargs::NamedTuple`
Arguments for transition() Arguments for transition()
- `totalsample::Integer` - `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally) Total number to sample from the current node (i.e. expand new node horizontally)
# Return # Return
- None - None
# Signature # Signature
""" """
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple; # function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3) # totalsample::Integer=3)
# # not use Any[] because I want to preserve result order # # not use Any[] because I want to preserve result order
# results = Vector{Any}(undef, totalsample) # results = Vector{Any}(undef, totalsample)
# @sync for i in 1:totalsample # @sync for i in 1:totalsample
# @spawn begin # @spawn begin
# result = transition(deepcopy(node.state), deepcopy(transitionargs)) # result = transition(deepcopy(node.state), deepcopy(transitionargs))
# results[i] = result # results[i] = result
# end # end
# end # end
# for result in results # for result in results
# newNodeKey::AbstractString = result[:newNodeKey] # newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate] # newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue] # progressvalue::Integer = result[:progressvalue]
# """ # """
# [] newNodeKey ∉ keys(node.children). # [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to # New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state # one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this # semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator. # similar situation to improve decisionMaker and evaluator.
# """ # """
# if newNodeKey ∉ keys(node.children) # if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] = # node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], # MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}()) # newstate[:isterminal], node, Dict{String, MCTSNode}())
# end # end
# end # end
# end # end
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3) totalsample::Integer=3)
nthSample = 0 nthSample = 0
while true while true
nthSample += 1 nthSample += 1
if nthSample <= totalsample if nthSample <= totalsample
result = transition(node.state, transitionargs) result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate] newstate::AbstractDict = result[:newstate]
progressvalue::Integer = result[:progressvalue] progressvalue::Integer = result[:progressvalue]
""" """
[] newNodeKey ∉ keys(node.children). [] newNodeKey ∉ keys(node.children).
New state may have semantic vector close enought to New state may have semantic vector close enought to
one of existing child state. Which can be assume that they are the same state one of existing child state. Which can be assume that they are the same state
semantically-wise i.e. De javu. This could be used to recall lessons for this semantically-wise i.e. De javu. This could be used to recall lessons for this
similar situation to improve decisionMaker and evaluator. similar situation to improve decisionMaker and evaluator.
""" """
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}()) newstate[:isterminal], node, Dict{String, MCTSNode}())
end end
else else
break break
end end
end end
end end
""" Simulate interactions between agent and environment """ Simulate interactions between agent and environment
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node that will be a simulation starting point. node that will be a simulation starting point.
- `transition::Function` - `transition::Function`
A user function that handles how state transition. A user function that handles how state transition.
- `transitionargs::NamedTuple` - `transitionargs::NamedTuple`
Arguments for everything the user will use within transition(). Arguments for everything the user will use within transition().
- `maxdepth::Integer` - `maxdepth::Integer`
maximum depth level MCTS goes vertically. maximum depth level MCTS goes vertically.
- totalsample::Integer - totalsample::Integer
Total number to sample from the current node (i.e. expand new node horizontally) Total number to sample from the current node (i.e. expand new node horizontally)
# Return # Return
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}` - `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3 maxdepth::Integer=3, totalsample::Integer=3
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}} )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
for depth in 1:maxdepth for depth in 1:maxdepth
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
if node.isterminal if node.isterminal
terminalstate = node.state terminalstate = node.state
break break
else else
expand(node, transition, transitionargs; expand(node, transition, transitionargs;
totalsample=totalsample) totalsample=totalsample)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
end end
""" """
# Arguments # Arguments
# Return # Return
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO # TODO
- [] update docstring - [] update docstring
- [x] implement the function - [x] implement the function
# Signature # Signature
""" """
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
currentstate_latestThoughtKey, currentstate_latestThoughtIndice = currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
currentstate_nextIndice = currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice") latestActionKey = Symbol("action_$currentstate_nextIndice")
_, thoughtDict_latestThoughtIndice = _, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought") GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1 if thoughtDict_latestThoughtIndice == -1
(:thought, :action) (:thought, :action)
else else
( (
Symbol("thought_$thoughtDict_latestThoughtIndice"), Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"), Symbol("action_$thoughtDict_latestThoughtIndice"),
) )
end end
# add Thought, action, observation to thoughtHistory # add Thought, action, observation to thoughtHistory
newstate = deepcopy(currentstate) newstate = deepcopy(currentstate)
newstate[:thoughtHistory][currentstate_latestThoughtKey] = newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey] thoughtDict[thoughtDict_latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
newObservationKey = Symbol("observation_$(currentstate_nextIndice)") newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward newstate[:reward] = reward
newstate[:select] = select newstate[:select] = select
newstate[:isterminal] = isterminal newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate) return (newNodeKey, newstate)
end end
end # module mcts end # module mcts

View File

@@ -1,116 +1,116 @@
module type module type
export MCTSNode export MCTSNode
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree """ a node for MCTS search tree
# Arguments # Arguments
- `state::T` - `state::T`
a state of a game. Can be a Dict or something else. a state of a game. Can be a Dict or something else.
- `visits::Integer ` - `visits::Integer `
number of time the game visits this state number of time the game visits this state
- `stateValue::Float64` - `stateValue::Float64`
state value state value
- `children::Dict{T, MCTSNode}` - `children::Dict{T, MCTSNode}`
children node children node
# Return # Return
- `nothing` - `nothing`
# Example # Example
```jldoctest ```jldoctest
julia> state = Dict( julia> state = Dict(
:info=> Dict(), # keyword info :info=> Dict(), # keyword info
:thoughtHistory=> Dict( :thoughtHistory=> Dict(
:question=> _, :question=> _,
:thought_1=> _, :thought_1=> _,
:action_1=> _, :action_1=> _,
:observation_1=> _, :observation_1=> _,
:thought_2=> _, :thought_2=> _,
... ...
) )
) )
``` ```
# TODO # TODO
[] update docstring [] update docstring
# Signature # Signature
""" """
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2 nodekey::T2
state::T1 state::T1
visits::Integer visits::Integer
progressvalue::Number # estimate value by LLM's reasoning progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node) statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node)
reward::Number # this node's immediate reward reward::Number # this node's immediate reward
isterminal::Bool isterminal::Bool
parent::Union{MCTSNode, Nothing} parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode} children::Dict{String, MCTSNode}
end end
end # module type end # module type

View File

@@ -1,139 +1,139 @@
module util module util
export UCTselect export UCTselect
using ..type using ..type
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Select a node based on UCT score """ Select a node based on UCT score
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
mcts node mcts node
- `w::T` - `w::T`
exploration weight. Value is usually between 1 to 2. exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%. Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
Value 2.0 makes MCTS aggressively search the tree. Value 2.0 makes MCTS aggressively search the tree.
# Return # Return
- `selectedNode::MCTSNode` - `selectedNode::MCTSNode`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# Signature # Signature
""" """
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
maxUCT = -Inf maxUCT = -Inf
selectedNode = nothing selectedNode = nothing
for (childState, childNode) in node.children for (childState, childNode) in node.children
UCTvalue = UCTvalue =
if childNode.visits != 0 if childNode.visits != 0
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
childNode.statevalue + weightedterm childNode.statevalue + weightedterm
else # node.visits == 0 makes sqrt() in explore term error else # node.visits == 0 makes sqrt() in explore term error
childNode.progressvalue # exploit term childNode.progressvalue # exploit term
end end
if UCTvalue > maxUCT if UCTvalue > maxUCT
maxUCT = UCTvalue maxUCT = UCTvalue
selectedNode = childNode selectedNode = childNode
end end
end end
return selectedNode return selectedNode
end end
end # module util end # module util