This commit is contained in:
narawat lamaiin
2025-03-07 13:33:38 +07:00
parent 6920be2334
commit 9add88b145
10 changed files with 28 additions and 1401 deletions

View File

@@ -1,476 +0,0 @@
# This file is machine-generated - editing it directly is not advised
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "b7e1f171d36dc4812d6c1445da530f513320e6cd"
[[deps.AliasTables]]
deps = ["PtrArrays", "Random"]
git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff"
uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
version = "1.1.3"
[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[deps.Calculus]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
version = "0.5.1"
[[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.3.5"
[[deps.Compat]]
deps = ["TOML", "UUIDs"]
git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.15.0"
weakdeps = ["Dates", "LinearAlgebra"]
[deps.Compat.extensions]
CompatLinearAlgebraExt = "LinearAlgebra"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0"
[[deps.DataAPI]]
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0"
[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.20"
[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.109"
[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test"
[deps.Distributions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.9.3"
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"
[[deps.DualNumbers]]
deps = ["Calculus", "NaNMath", "SpecialFunctions"]
git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.8"
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
[[deps.FillArrays]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.11.0"
weakdeps = ["PDMats", "SparseArrays", "Statistics"]
[deps.FillArrays.extensions]
FillArraysPDMatsExt = "PDMats"
FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics"
[[deps.GeneralUtils]]
deps = ["DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"]
path = "/appfolder/app/privatejuliapkg/GeneralUtils"
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
version = "0.1.0"
[[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.23"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"
[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"
[[deps.JSON3]]
deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b"
uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.14.0"
[deps.JSON3.extensions]
JSON3ArrowExt = ["ArrowTypes"]
[deps.JSON3.weakdeps]
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
[[deps.JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.31"
[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.4"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"
[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.11.0+1"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.27"
[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
[deps.LogExpFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[deps.LoweredCodeUtils]]
deps = ["JuliaInterpreter"]
git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "2.4.6"
[[deps.MQTTClient]]
deps = ["Distributed", "Random", "Sockets"]
git-tree-sha1 = "f2597b290d4bf17b577346153cd2ddf9accb5c26"
uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
version = "0.3.1"
weakdeps = ["PrecompileTools"]
[deps.MQTTClient.extensions]
PrecompileMQTT = "PrecompileTools"
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+1"
[[deps.Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.2.0"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2023.1.10"
[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.2"
[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4"
[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2"
[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"
[[deps.OrderedCollections]]
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.3"
[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.31"
[[deps.Parsers]]
deps = ["Dates", "PrecompileTools", "UUIDs"]
git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.8.1"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0"
[[deps.PrecompileTools]]
deps = ["Preferences"]
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
version = "1.2.1"
[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.3"
[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[deps.PtrArrays]]
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.2.0"
[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.9.4"
[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[deps.Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.2.2"
[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"
[[deps.Revise]]
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
git-tree-sha1 = "12aa2d7593df490c407a3bbd8b86b8b515017f3e"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.5.14"
[[deps.Rmath]]
deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.7.1"
[[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.4.2+0"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.2.1"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"
[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.4.0"
[deps.SpecialFunctions.extensions]
SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
[deps.SpecialFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
version = "1.10.0"
[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.7.0"
[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.3"
[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "1.3.1"
[deps.StatsFuns.extensions]
StatsFunsChainRulesCoreExt = "ChainRulesCore"
StatsFunsInverseFunctionsExt = "InverseFunctions"
[deps.StatsFuns.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.StructTypes]]
deps = ["Dates", "UUIDs"]
git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
version = "1.10.0"
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1"
[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3"
[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0"
[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.52.0+1"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+2"

View File

@@ -1,8 +0,0 @@
name = "LLMMCTS"
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
authors = ["narawat lamaiin <narawat@outlook.com>"]
version = "0.1.0"
[deps]
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

View File

@@ -1,28 +0,0 @@
module LLMMCTS
# export agent
""" Order by dependencies of each file. The 1st included file must not depend on any other
files and each file can only depend on the file included before it.
"""
include("type.jl")
using .type
include("util.jl")
using .util
include("mcts.jl")
using .mcts
include("interface.jl")
using .interface
# ---------------------------------------------- 100 --------------------------------------------- #
end # module LLMMCTS

View File

@@ -1,180 +0,0 @@
module interface
export runMCTS
using ..type, ..mcts
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `a::agent`
one of Yiem's agents
- `initial state`
initial state
- `decisionMaker::Function`
decide what action to take
- `evaluator::Function`
assess the value of the state
- `reflector::Function`
generate lesson from trajectory and reward
- `isterminal::Function`
determine whether a given state is a terminal state
- `n::Integer`
how many times action will be sampled from decisionMaker
- `w::Float64`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
Value 2.0 makes MCTS aggressively search the tree
# Return
- `plan::Vector{Dict}`
best plan
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] return best action
# Signature
"""
function runMCTS(
config::T1,
initialState,
decisionMaker::Function,
evaluator::Function,
reflector::Function,
transition::Function,
;
totalsample::Integer=3,
maxDepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
) where {T1<:AbstractDict}
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward)
else
expand(config, node, decisionMaker, evaluator, reflector, transition;
totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(config, leafNode, decisionMaker, evaluator,
reflector, transition; maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing #XXX not sure why I need this
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
end
bestNextState = selectBestNextState(root)
besttrajectory = selectBestTrajectory(root)
return (bestNextState.state, besttrajectory.state)
end
end # module interface

View File

@@ -1,438 +0,0 @@
module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0
for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
else
for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
end
return node.children[nodekey]
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
while !isleaf(node)
node = selectBestNextState(node)
end
return node
end
""" Backpropagate reward along the simulation chain
# Arguments
- `node::MCTSNode`
leaf node of a search tree
- `simTrajectoryReward::T`
total reward from trajectory simulation
# Return
- `No return`
# Example
```jldoctest
julia>
```
# Signature
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent
end
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?",
)
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Example
```jldoctest
julia>
```
# Signature
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing
# loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward
if childNode.reward > 0 #XXX for testing. remove when done.
println("")
end
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childNode.nodekey
end
end
return node.children[nodekey]
end
""" Expand selected node
# Arguments
- `a::T1`
One of YiemAgent's agent
- `node::MCTSNode`
MCTS node
- `state::T2`
a state of a game. Can be a Dict or something else.
- `decisionMaker::Function`
a function that output Thought and Action
- `evaluator::Function`
a function that output trajectory progress score
# Return
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] try loop should limit to 3 times. if not succeed, skip
[] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise.
[x] store feedback -> state -> agent.
# Signature
"""
function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; totalsample::Integer=3
) where {T1<:AbstractDict}
nthSample = 0
while true
nthSample += 1
if nthSample <= totalsample
newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker,
evaluator, reflector)
if newNodeKey keys(node.children)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}())
end
else
break
end
end
end
""" Simulate interactions between agent and environment
# Arguments
- `a::T`
one of YiemAgent's agent
- `node::MCTSNode`
node that will be a simulation starting point.
- `decisionMaker::Function`
function that receive state return Thought and Action
# Return
- `simTrajectoryReward::Number`
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
# Signature
"""
function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
simTrajectoryReward = 0.0
terminalstate = nothing
for depth in 1:maxDepth
simTrajectoryReward += node.reward
if node.isterminal
terminalstate = node.state
break
else
expand(config, node, decisionMaker, evaluator, reflector, transition;
totalsample=totalsample)
node = selectChildNode(node)
end
end
return (simTrajectoryReward, terminalstate)
end
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [x] implement the function
# Signature
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice")
_, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1
(:thought, :action)
else
(
Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"),
)
end
# add Thought, action, observation to thoughtHistory
newstate = deepcopy(currentstate)
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
end
end # module mcts

View File

@@ -1,116 +0,0 @@
module type
export MCTSNode
# ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree
# Arguments
- `state::T`
a state of a game. Can be a Dict or something else.
- `visits::Integer `
number of time the game visits this state
- `stateValue::Float64`
state value
- `children::Dict{T, MCTSNode}`
children node
# Return
- `nothing`
# Example
```jldoctest
julia> state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
```
# TODO
[] update docstring
# Signature
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2
state::T1
visits::Integer
progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # store discounted commulative reward (gather from its child node)
reward::Number # this node's own reward
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
end # module type

View File

@@ -1,139 +0,0 @@
module util
export UCTselect
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
""" Select a node based on UCT score
# Arguments
- `node::MCTSNode`
mcts node
- `w::T`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
Value 2.0 makes MCTS aggressively search the tree.
# Return
- `selectedNode::MCTSNode`
# Example
```jldoctest
julia>
```
# Signature
"""
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
maxUCT = -Inf
selectedNode = nothing
for (childState, childNode) in node.children
UCTvalue =
if childNode.visits != 0
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
childNode.statevalue + weightedterm
else # node.visits == 0 makes sqrt() in explore term error
childNode.progressvalue # exploit term
end
if UCTvalue > maxUCT
maxUCT = UCTvalue
selectedNode = childNode
end
end
return selectedNode
end
end # module util

View File

@@ -2,6 +2,7 @@ module interface
export runMCTS export runMCTS
using Base.Threads
using ..type, ..mcts, ..util using ..type, ..mcts, ..util
@@ -45,7 +46,8 @@ function runMCTS(
transition::Function, transition::Function,
transitionargs::NamedTuple, transitionargs::NamedTuple,
; ;
totalsample::Integer=3, horizontalSampleExpansionPhase::Integer=3,
horizontalSampleSimulationPhase::Integer=3,
maxdepth::Integer=3, maxdepth::Integer=3,
maxiterations::Integer=10, maxiterations::Integer=10,
explorationweight::Number=1.0, explorationweight::Number=1.0,
@@ -69,10 +71,20 @@ function runMCTS(
backpropagate(node, node.reward) backpropagate(node, node.reward)
else else
_ = expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
totalsample=totalsample) horizontalSample=horizontalSampleExpansionPhase)
#[WORKING] make simulation parallel, leafNodes must be newly expanded nodes
leafNode = selectChildNode(node) leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample)
# outputch = Channel(8)
#[WORKING] test whether multiple spawn retain result leafNode's child node
@spawn simulate(outputch, leafNode, transition, transitionargs;
maxdepth=maxdepth, horizontalSample=horizontalSampleSimulationPhase)
# if terminalstate !== nothing #XXX not sure why I need this # if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end # end
@@ -82,10 +94,9 @@ function runMCTS(
# JSON3.pretty(io, terminalstate) # JSON3.pretty(io, terminalstate)
# end # end
backpropagate(leafNode, simTrajectoryReward) # result = take!(outputch)
# delete all child node, no need for child node that was created during simulation backpropagate(leafNode, simTrajectoryReward)
leafNode.children = Dict{String,MCTSNode}()
end end
# stop if the early stop condition is met # stop if the early stop condition is met

View File

@@ -231,13 +231,13 @@ end
# end # end
# end # end
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3) horizontalSample::Integer=3)
nthSample = 0 nthSample = 0
listOfNewNodeId = [] listOfNewNodeId = []
while true while true
nthSample += 1 nthSample += 1
if nthSample <= totalsample if nthSample <= horizontalSample
result = transition(node.state, transitionargs) result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate] newstate::AbstractDict = result[:newstate]
@@ -252,9 +252,9 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
""" """
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
push!(listOfNewNodeId, newNodeKey) push!(listOfNewNodeId, newNodeKey)
node.children[newNodeKey] = newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
node.children[newNodeKey] = newNode
end end
else else
return listOfNewNodeId return listOfNewNodeId
@@ -274,7 +274,7 @@ end
Arguments for everything the user will use within transition(). Arguments for everything the user will use within transition().
- `maxdepth::Integer` - `maxdepth::Integer`
maximum depth level MCTS goes vertically. maximum depth level MCTS goes vertically.
- totalsample::Integer - horizontalSample::Integer
Total number to sample from the current node (i.e. expand new node horizontally) Total number to sample from the current node (i.e. expand new node horizontally)
# Return # Return
@@ -282,8 +282,8 @@ end
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(outputchannel::Channel, node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3 maxdepth::Integer=3, horizontalSample::Integer=3
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}} )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
@@ -297,12 +297,13 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
break break
else else
_ = expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
totalsample=totalsample) horizontalSample=horizontalSample)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) put!(outputchannel, (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate))
# return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
end end