diff --git a/Manifest.toml b/Manifest.toml index cf2277b..8d88c52 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,12 +2,145 @@ julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "ae4c47fe855fc91b54d755564b9f3a9249b10b0a" +project_hash = "b7e1f171d36dc4812d6c1445da530f513320e6cd" + +[[deps.AliasTables]] +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" +uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" +version = "1.1.3" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" + +[[deps.CodeTracking]] +deps = ["InteractiveUtils", "UUIDs"] +git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "1.3.5" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Distributions]] +deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.109" + + [deps.Distributions.extensions] + DistributionsChainRulesCoreExt = "ChainRulesCore" + DistributionsDensityInterfaceExt = "DensityInterface" + DistributionsTestExt = "Test" + + [deps.Distributions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" +weakdeps = ["PDMats", "SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + +[[deps.GeneralUtils]] +deps = ["DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"] +path = "/appfolder/app/privatejuliapkg/GeneralUtils" +uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" +version = "0.1.0" + +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.23" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" @@ -20,15 +153,148 @@ version = "1.14.0" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +[[deps.JuliaInterpreter]] +deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] +git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b" +uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +version = "0.9.31" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.27" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoweredCodeUtils]] +deps = ["JuliaInterpreter"] +git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426" +uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" +version = "2.4.6" + +[[deps.MQTTClient]] +deps = ["Distributed", "Random", "Sockets"] +git-tree-sha1 = "f2597b290d4bf17b577346153cd2ddf9accb5c26" +uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" +version = "0.3.1" +weakdeps = ["PrecompileTools"] + + [deps.MQTTClient.extensions] + PrecompileMQTT = "PrecompileTools" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.31" + [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.8.1" +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" @@ -45,28 +311,166 @@ version = "1.4.3" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[deps.PtrArrays]] +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.0" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.9.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.Revise]] +deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"] +git-tree-sha1 = "12aa2d7593df490c407a3bbd8b86b8b515017f3e" +uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" +version = "3.5.14" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.1" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.4.2+0" + [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + + [deps.SpecialFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StatsFuns]] +deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.3.1" + + [deps.StatsFuns.extensions] + StatsFunsChainRulesCoreExt = "ChainRulesCore" + StatsFunsInverseFunctionsExt = "InverseFunctions" + + [deps.StatsFuns.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + [[deps.StructTypes]] deps = ["Dates", "UUIDs"] git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" version = "1.10.0" +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" version = "1.0.3" +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index be4f8d8..74a4bb8 100644 --- a/Project.toml +++ b/Project.toml @@ -4,4 +4,5 @@ authors = ["narawat lamaiin "] version = "0.1.0" [deps] +GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" diff --git a/src/interface.jl b/src/interface.jl index 5e0ce04..3e02955 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -47,7 +47,7 @@ julia> # Signature """ function runMCTS( - workDict::T1, + config::T1, initialState, decisionMaker::Function, evaluator::Function, @@ -74,11 +74,11 @@ function runMCTS( # do nothing then go directly to backpropagation backpropagate(leafNode, node.reward) else - expand(workDict, node, decisionMaker, evaluator, reflector, transition; + expand(config, node, decisionMaker, evaluator, reflector, transition; totalsample=totalsample) leafNode = selectChildNode(node) - simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator, - reflector; maxDepth=maxDepth, totalsample=totalsample) + simTrajectoryReward, terminalstate = simulate(config, leafNode, decisionMaker, evaluator, + reflector, transition; maxDepth=maxDepth, totalsample=totalsample) if terminalstate !== nothing #XXX not sure why I need this terminalstate[:totalTrajectoryReward] = simTrajectoryReward end diff --git a/src/mcts.jl b/src/mcts.jl index 63c708a..ac7bb18 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -1,7 +1,9 @@ module mcts export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode, - expand, simulate, mctstransition + expand, simulate, makeNewState + +using GeneralUtils using ..type @@ -242,7 +244,7 @@ julia> # Signature """ -function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function, +function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function, reflector::Function, transition::Function; totalsample::Integer=3 ) where {T1<:AbstractDict} @@ -250,32 +252,8 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator while true nthSample += 1 if nthSample <= totalsample - thoughtDict = decisionMaker(workDict, node.state) - println("---> expand() sample $nthSample") - pprintln(node.state[:thoughtHistory]) - pprintln(thoughtDict) - newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict) - - stateevaluation, progressvalue = evaluator(workDict, newstate) - - if newstate[:reward] < 0 - pprint(newstate[:thoughtHistory]) - newstate[:evaluation] = stateevaluation - newstate[:lesson] = reflector(workDict, newstate) - - # store new lesson for later use - lessonDict = copy(JSON3.read("lesson.json")) - latestLessonKey, latestLessonIndice = - GeneralUtils.findHighestIndexKey(lessonDict, "lesson") - nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1 - newLessonKey = Symbol("lesson_$(nextIndice)") - lessonDict[newLessonKey] = newstate - open("lesson.json", "w") do io - JSON3.pretty(io, lessonDict) - end - print("---> reflector()") - end - + newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker, + evaluator, reflector) if newNodeKey ∉ keys(node.children) node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], @@ -311,8 +289,8 @@ julia> # Signature """ -function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, - reflector::Function; maxDepth::Integer=3, totalsample::Integer=3 +function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, + reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3 )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict} simTrajectoryReward = 0.0 @@ -324,7 +302,8 @@ function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluato terminalstate = node.state break else - expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample) + expand(config, node, decisionMaker, evaluator, reflector, transition; + totalsample=totalsample) node = selectChildNode(node) end end @@ -333,71 +312,58 @@ function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluato end - -""" Get a new state +""" # Arguments - - `a::T1` - one of YiemAgent's agent - - `state::T2` - current game state - - `thoughtDict::T3` - contain Thought, Action, Observation - - `isterminal::Function` - a function to determine terminal state - + # Return - - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}` # Example ```jldoctest -julia> state = Dict{Symbol, Dict{Symbol, Any}}( - :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."), - :storeinfo => Dict(), - :customerinfo => Dict() - ) -julia> thoughtDict = Dict( - :question=> "I want to buy a bottle of wine.", - :thought_1=> "The customer wants to buy a bottle of wine.", - :action_1=> Dict{Symbol, Any}( - :name=>"Chatbox", - :input=>"What occasion are you buying the wine for?", - ), - :observation_1 => "" - ) +julia> ``` # TODO - - [] add other actions - - [WORKING] add embedding of newstate and store in newstate[:embedding] + - [] update docstring + - [x] implement the function # Signature """ -function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2 - )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict} - error("--> mctstransition") - # actionname = thoughtDict[:action][:name] - # actioninput = thoughtDict[:action][:input] +function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, + reward::T3, isterminal::Bool + )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} - # # map action and input() to llm function - # response, select, reward, isterminal = - # if actionname == "chatbox" - # # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean - # # so that other simulation start from this same node is not contaminated with actioninput - # virtualWineUserChatbox(workDict, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer - # elseif actionname == "winestock" - # winestock(a, actioninput) - # elseif actionname == "recommendbox" - # virtualWineUserRecommendbox(workDict, actioninput) - # else - # error("undefined LLM function. Requesting $actionname") - # end + currentstate_latestThoughtKey, currentstate_latestThoughtIndice = + GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") + currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 + currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") + latestActionKey = Symbol("action_$currentstate_nextIndice") - # newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal) - # if actionname == "chatbox" - # push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) ) - # push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response)) - # end + _, thoughtDict_latestThoughtIndice = + GeneralUtils.findHighestIndexKey(thoughtDict, "thought") + + thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = + if thoughtDict_latestThoughtIndice == -1 + (:thought, :action) + else + ( + Symbol("thought_$thoughtDict_latestThoughtIndice"), + Symbol("action_$thoughtDict_latestThoughtIndice"), + ) + end + + # add Thought, action, observation to thoughtHistory + newstate = deepcopy(currentstate) + newstate[:thoughtHistory][currentstate_latestThoughtKey] = + thoughtDict[thoughtDict_latestThoughtKey] + newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] + newObservationKey = Symbol("observation_$(currentstate_nextIndice)") + newstate[:thoughtHistory][newObservationKey] = response + newstate[:reward] = reward + newstate[:select] = select + newstate[:isterminal] = isterminal + + newNodeKey = GeneralUtils.uuid4snakecase() return (newNodeKey, newstate) end @@ -460,13 +426,6 @@ end - - - - - - -