From 351bccc0592b4be583b1fcabbc72289f5d8ccfdd Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Mon, 14 Oct 2024 09:13:46 +0700 Subject: [PATCH] update --- Manifest.toml | 159 +++++++++++++++++++++++++---------------------- src/LLMMCTS.jl | 10 +++ src/interface.jl | 89 +++++++++++++++++++++----- src/type.jl | 4 +- 4 files changed, 169 insertions(+), 93 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 9582465..028b55f 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.4" +julia_version = "1.11.0" manifest_format = "2.0" project_hash = "b7e1f171d36dc4812d6c1445da530f513320e6cd" @@ -12,13 +12,15 @@ version = "1.1.3" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" +version = "1.1.2" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] @@ -26,12 +28,6 @@ git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.14" -[[deps.CodeTracking]] -deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "7eee164f122511d3e4e1ebadb7956939ea7e1c77" -uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.6" - [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" @@ -64,10 +60,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.16.0" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" +version = "1.7.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -83,16 +79,18 @@ version = "1.0.0" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "e6c693a0e4394f8fda0e51a5bdf5aef26f8235e9" +git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.111" +version = "0.25.112" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -131,12 +129,13 @@ version = "0.9.22" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "fd0002c0b5362d7eb952450ad5eb742443340d6e" +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.12.0" +version = "1.13.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -147,9 +146,10 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" [[deps.GeneralUtils]] -deps = ["CSV", "DataFrames", "DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"] +deps = ["CSV", "DataFrames", "DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "PrettyPrinting", "Random", "SHA", "UUIDs"] path = "/appfolder/app/privatejuliapkg/GeneralUtils" uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" version = "0.1.0" @@ -176,6 +176,7 @@ version = "1.4.2" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" @@ -194,9 +195,9 @@ version = "1.0.0" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.1" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] @@ -210,12 +211,6 @@ version = "1.14.0" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" -[[deps.JuliaInterpreter]] -deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "4b415b6cccb9ab61fec78a621572c82ac7fa5776" -uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.35" - [[deps.LaTeXStrings]] git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" @@ -229,16 +224,17 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "8.6.0+0" [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" [[deps.LibGit2_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" +version = "1.7.2+0" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] @@ -247,10 +243,12 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] @@ -270,12 +268,7 @@ version = "0.3.28" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoweredCodeUtils]] -deps = ["JuliaInterpreter"] -git-tree-sha1 = "1ce1834f9644a8f7c011eb0592b7fd6c42c90653" -uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "3.0.1" +version = "1.11.0" [[deps.MQTTClient]] deps = ["Distributed", "Random", "Sockets"] @@ -290,11 +283,12 @@ weakdeps = ["PrecompileTools"] [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.6+0" [[deps.Missings]] deps = ["DataAPI"] @@ -304,10 +298,11 @@ version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2023.12.12" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -316,7 +311,7 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.27+1" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -347,9 +342,15 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.8.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.11.0" + + [deps.Pkg.extensions] + REPLExt = "REPL" + + [deps.Pkg.weakdeps] + REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.PooledArrays]] deps = ["DataAPI", "Future"] @@ -369,63 +370,60 @@ git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.4.3" +[[deps.PrettyPrinting]] +git-tree-sha1 = "142ee93724a9c5d04d78df7006670a93ed1b244e" +uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" +version = "0.4.2" + [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" +version = "2.4.0" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" [[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" +version = "1.2.1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "e237232771fdafbae3db5c31275303e056afaa9f" +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.10.1" +version = "2.11.1" -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "1.2.2" -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.Revise]] -deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "7b7850bb94f75762d567834d7e9802fc22d62f9c" -uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.5.18" - [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e60724fd3beea548353984dc61c943ecddb0e29a" +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.3+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -439,9 +437,11 @@ version = "1.4.5" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" [[deps.SortingAlgorithms]] deps = ["DataStructures"] @@ -452,7 +452,7 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" +version = "1.11.0" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] @@ -467,9 +467,14 @@ version = "2.4.0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -485,9 +490,9 @@ version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" [deps.StatsFuns.extensions] StatsFunsChainRulesCoreExt = "ChainRulesCore" @@ -499,15 +504,15 @@ version = "1.3.1" [[deps.StringManipulation]] deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" +version = "0.4.0" [[deps.StructTypes]] deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" +version = "1.11.0" [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] @@ -516,7 +521,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "7.7.0+0" [[deps.TOML]] deps = ["Dates"] @@ -541,16 +546,18 @@ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" [[deps.TranscodingStreams]] -git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.2" +version = "0.11.3" [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" [[deps.WeakRefStrings]] deps = ["DataAPI", "InlineStrings", "Parsers"] @@ -571,12 +578,12 @@ version = "1.2.13+1" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.11.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.59.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] diff --git a/src/LLMMCTS.jl b/src/LLMMCTS.jl index f5f6e76..d61a2e8 100644 --- a/src/LLMMCTS.jl +++ b/src/LLMMCTS.jl @@ -22,6 +22,16 @@ module LLMMCTS # ---------------------------------------------- 100 --------------------------------------------- # +""" version 0.0.2 + Todo: + - [] + + Change from version: 0.0.1 + - + + All features + +""" diff --git a/src/interface.jl b/src/interface.jl index f48efa8..84b9ae1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -45,13 +45,14 @@ function runMCTS( transition::Function, transitionargs::NamedTuple, ; - totalsample::Integer=3, - maxdepth::Integer=3, - maxiterations::Integer=10, + totalsample::Integer=3, + maxdepth::Integer=3, + maxiterations::Integer=10, explorationweight::Number=1.0, - )::NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}} where {T<:Any} - - root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) + earlystop::Union{Function,Nothing}=nothing +)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} + + root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) for nth in 1:maxiterations node = root @@ -60,19 +61,20 @@ function runMCTS( while !isleaf(node) node = UCTselect(node, explorationweight) end + if node.isterminal # MCTS arrive at the leaf node that is also a terminal state, - # do nothing then go directly to backpropagation - backpropagate(leafNode, node.reward) + # do nothing then go directly to backpropagation. It means the end of this iteration + backpropagate(node, node.reward) else - expand(node, transition, transitionargs; - totalsample=totalsample) + expand(node, transition, transitionargs; + totalsample=totalsample) leafNode = selectChildNode(node) - simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; - maxdepth=maxdepth, totalsample=totalsample) - if terminalstate !== nothing #XXX not sure why I need this - terminalstate[:totalTrajectoryReward] = simTrajectoryReward - end + simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; + maxdepth=maxdepth, totalsample=totalsample) + # if terminalstate !== nothing #XXX not sure why I need this + # terminalstate[:totalTrajectoryReward] = simTrajectoryReward + # end #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation # open("trajectory.json", "w") do io @@ -81,6 +83,13 @@ function runMCTS( backpropagate(leafNode, simTrajectoryReward) end + + # stop if the tree early stop condition is met + if typeof(earlystop) <: Function + result = earlystop(node.state) + + break + end end bestNextState = selectBestNextNode(root) @@ -90,6 +99,56 @@ function runMCTS( end +# function runMCTS( +# initialstate::T, +# transition::Function, +# transitionargs::NamedTuple, +# ; +# totalsample::Integer=3, +# maxdepth::Integer=3, +# maxiterations::Integer=10, +# explorationweight::Number=1.0, +# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} + +# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) + +# for nth in 1:maxiterations +# node = root +# node.visits += 1 + +# while !isleaf(node) +# node = UCTselect(node, explorationweight) +# end +# if node.isterminal +# # MCTS arrive at the leaf node that is also a terminal state, +# # do nothing then go directly to backpropagation. It means the end of this iteration +# backpropagate(leafNode, node.reward) +# else +# expand(node, transition, transitionargs; +# totalsample=totalsample) +# leafNode = selectChildNode(node) +# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; +# maxdepth=maxdepth, totalsample=totalsample) +# # if terminalstate !== nothing #XXX not sure why I need this +# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward +# # end + +# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation +# # open("trajectory.json", "w") do io +# # JSON3.pretty(io, terminalstate) +# # end + +# backpropagate(leafNode, simTrajectoryReward) +# end +# end + +# bestNextState = selectBestNextNode(root) +# besttrajectory = selectBestTrajectoryNode(root) + +# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) +# end + + diff --git a/src/type.jl b/src/type.jl index bb6cdce..a3503d7 100644 --- a/src/type.jl +++ b/src/type.jl @@ -44,8 +44,8 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} state::T1 visits::Integer progressvalue::Number # estimate value by LLM's reasoning - statevalue::Number # store discounted commulative reward (gather from its child node) - reward::Number # this node's own reward + statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node) + reward::Number # this node's immediate reward isterminal::Bool parent::Union{MCTSNode, Nothing} children::Dict{String, MCTSNode}