This commit is contained in:
narawat lamaiin
2024-07-10 11:38:59 +07:00
parent 9e39d54c4b
commit 05830f3d9a
3 changed files with 167 additions and 85 deletions

View File

@@ -20,6 +20,12 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[deps.Base64]] [[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
version = "0.10.14"
[[deps.Calculus]] [[deps.Calculus]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
@@ -32,6 +38,12 @@ git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.3.5" version = "1.3.5"
[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.5"
[[deps.Compat]] [[deps.Compat]]
deps = ["TOML", "UUIDs"] deps = ["TOML", "UUIDs"]
git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248"
@@ -47,17 +59,33 @@ deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0" version = "1.1.1+0"
[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.1.1"
[[deps.DataAPI]] [[deps.DataAPI]]
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0" version = "1.16.0"
[[deps.DataFrames]]
deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "1.6.1"
[[deps.DataStructures]] [[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"] deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.20" version = "0.18.20"
[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
version = "1.0.0"
[[deps.Dates]] [[deps.Dates]]
deps = ["Printf"] deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -99,6 +127,12 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.8" version = "0.6.8"
[[deps.FilePathsBase]]
deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
version = "0.9.21"
[[deps.FileWatching]] [[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
@@ -114,8 +148,12 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"]
FillArraysSparseArraysExt = "SparseArrays" FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics" FillArraysStatisticsExt = "Statistics"
[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
[[deps.GeneralUtils]] [[deps.GeneralUtils]]
deps = ["DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"] deps = ["CSV", "DataFrames", "DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"]
path = "/appfolder/app/privatejuliapkg/GeneralUtils" path = "/appfolder/app/privatejuliapkg/GeneralUtils"
uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
version = "0.1.0" version = "0.1.0"
@@ -126,15 +164,37 @@ git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.23" version = "0.3.23"
[[deps.InlineStrings]]
deps = ["Parsers"]
git-tree-sha1 = "86356004f30f8e737eff143d57d41bd580e437aa"
uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
version = "1.4.1"
[deps.InlineStrings.extensions]
ArrowTypesExt = "ArrowTypes"
[deps.InlineStrings.weakdeps]
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
[[deps.InteractiveUtils]] [[deps.InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[deps.InvertedIndices]]
git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038"
uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
version = "1.3.0"
[[deps.IrrationalConstants]] [[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2" version = "0.2.2"
[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"
[[deps.JLLWrappers]] [[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"] deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
@@ -159,6 +219,11 @@ git-tree-sha1 = "a6adc2dcfe4187c40dc7c2c9d2128e326360e90a"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.32" version = "0.9.32"
[[deps.LaTeXStrings]]
git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"
[[deps.LibCURL]] [[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"] deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
@@ -295,6 +360,12 @@ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0" version = "1.10.0"
[[deps.PooledArrays]]
deps = ["DataAPI", "Future"]
git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3"
uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
version = "1.4.3"
[[deps.PrecompileTools]] [[deps.PrecompileTools]]
deps = ["Preferences"] deps = ["Preferences"]
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
@@ -307,6 +378,12 @@ git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
uuid = "21216c6a-2e73-6563-6e65-726566657250" uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.3" version = "1.4.3"
[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "2.3.2"
[[deps.Printf]] [[deps.Printf]]
deps = ["Unicode"] deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -363,6 +440,12 @@ version = "0.4.2+0"
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0" version = "0.7.0"
[[deps.SentinelArrays]]
deps = ["Dates", "Random"]
git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a"
uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
version = "1.4.5"
[[deps.Serialization]] [[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -423,6 +506,12 @@ version = "1.3.1"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.StringManipulation]]
deps = ["PrecompileTools"]
git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
version = "0.3.4"
[[deps.StructTypes]] [[deps.StructTypes]]
deps = ["Dates", "UUIDs"] deps = ["Dates", "UUIDs"]
git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
@@ -443,11 +532,36 @@ deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3" version = "1.0.3"
[[deps.TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.1"
[[deps.Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"]
git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.11.1"
[[deps.Tar]] [[deps.Tar]]
deps = ["ArgTools", "SHA"] deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0" version = "1.10.0"
[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.TranscodingStreams]]
git-tree-sha1 = "60df3f8126263c0d6b357b9a1017bb94f53e3582"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.11.0"
weakdeps = ["Random", "Test"]
[deps.TranscodingStreams.extensions]
TestExt = ["Test", "Random"]
[[deps.UUIDs]] [[deps.UUIDs]]
deps = ["Random", "SHA"] deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
@@ -455,6 +569,17 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[deps.Unicode]] [[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[deps.WeakRefStrings]]
deps = ["DataAPI", "InlineStrings", "Parsers"]
git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23"
uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
version = "1.4.2"
[[deps.WorkerUtilities]]
git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7"
uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
version = "1.6.1"
[[deps.Zlib_jll]] [[deps.Zlib_jll]]
deps = ["Libdl"] deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a" uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

View File

@@ -12,7 +12,7 @@ using ..type, ..mcts, ..util
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
# Arguments # Arguments
- `initial state` - `initialstate::T`
initial state initial state
- `transition::Function` - `transition::Function`
a function that defines how the state transitions a function that defines how the state transitions
@@ -32,21 +32,16 @@ using ..type, ..mcts, ..util
aggressively explore new state. aggressively explore new state.
# Return # Return
- `(bestNextState, BestFinalState)::Tuple` - `(bestNextState, BestFinalState)::@NamedTuple{bestNextState::T, bestFinalState::T}`
the best next state and the best final state the best next state and the best final state
# Example # Example
```jldoctest Refers to SQLLLM package
julia>
```
# TODO
[] update example
# Signature # Signature
""" """
function runMCTS( function runMCTS(
initialstate, initialstate::T,
transition::Function, transition::Function,
transitionargs::NamedTuple, transitionargs::NamedTuple,
; ;
@@ -54,7 +49,7 @@ function runMCTS(
maxdepth::Integer=3, maxdepth::Integer=3,
maxiterations::Integer=10, maxiterations::Integer=10,
explorationweight::Number=1.0, explorationweight::Number=1.0,
)::NamedTuple )::@NamedTuple{bestNextState::T, bestFinalState::T} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
@@ -88,8 +83,8 @@ function runMCTS(
end end
end end
bestNextState = selectBestNextState(root) bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectory(root) besttrajectory = selectBestTrajectoryNode(root)
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end end

View File

@@ -1,6 +1,6 @@
module mcts module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode, export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState expand, simulate, makeNewState
using GeneralUtils using GeneralUtils
@@ -20,18 +20,9 @@ using ..type
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature # Signature
""" """
function selectBestNextState(node::MCTSNode)::MCTSNode function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1 highestProgressValue = -1
nodekey = nothing nodekey = nothing
@@ -72,20 +63,11 @@ end
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature # Signature
""" """
function selectBestTrajectory(node::MCTSNode)::MCTSNode function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
while !isleaf(node) while !isleaf(node)
node = selectBestNextState(node) node = selectBestNextNode(node)
end end
return node return node
@@ -99,14 +81,12 @@ end
leaf node of a search tree leaf node of a search tree
- `simTrajectoryReward::T` - `simTrajectoryReward::T`
total reward from trajectory simulation total reward from trajectory simulation
- `discountRewardCoeff::AbstractFloat`
A discount coefficient that reduces future rewards. The further in the future a reward is, the lower
its value now.
# Return # Return
- `No return` - `None`
# Example
```jldoctest
julia>
```
# Signature # Signature
""" """
@@ -166,11 +146,6 @@ isleaf(node::MCTSNode)::Bool = isempty(node.children)
- `isrootnode::Bool` - `isrootnode::Bool`
true if the given node is root node, false otherwise true if the given node is root node, false otherwise
# Example
```jldoctest
julia>
```
# Signature # Signature
""" """
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
@@ -187,11 +162,6 @@ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node
# Example
```jldoctest
julia>
```
# Signature # Signature
""" """
function selectChildNode(node::MCTSNode)::MCTSNode function selectChildNode(node::MCTSNode)::MCTSNode
@@ -201,9 +171,6 @@ function selectChildNode(node::MCTSNode)::MCTSNode
# loop through node children dictionary to find the highest progress value # loop through node children dictionary to find the highest progress value
for (k, childNode) in node.children for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward potential = childNode.progressvalue + childNode.reward
if childNode.reward > 0 #XXX for testing. remove when done.
println("")
end
if potential > highestProgressValue if potential > highestProgressValue
highestProgressValue = potential highestProgressValue = potential
nodekey = childNode.nodekey nodekey = childNode.nodekey
@@ -214,34 +181,26 @@ function selectChildNode(node::MCTSNode)::MCTSNode
end end
""" Expand selected node """ Expand selected node.
# Arguments # Arguments
- `a::T1`
One of YiemAgent's agent
- `node::MCTSNode` - `node::MCTSNode`
MCTS node MCTS node
- `state::T2` - `transition::Function`
a state of a game. Can be a Dict or something else. A function that handles state transition.
- `decisionMaker::Function` - `transitionargs::NamedTuple`
a function that output Thought and Action Arguments for transition()
- `evaluator::Function` - `totalsample::Integer`
a function that output trajectory progress score Total number to sample from the current node (i.e. expand new node horizontally)
# Return # Return
- None
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO
[] update docstring
[] try loop should limit to 3 times. if not succeed, skip
[] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise.
[x] store feedback -> state -> agent.
# Signature # Signature
""" """
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
@@ -255,6 +214,14 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate] newstate::AbstractDict = result[:newstate]
progressvalue::Integer = result[:progressvalue] progressvalue::Integer = result[:progressvalue]
"""
[] newNodeKey ∉ keys(node.children).
The new state may have a semantic vector close enough to
one of the existing child states, in which case they can be assumed to be the same state
semantically, i.e. déjà vu. This could be used to recall lessons for this
similar situation to improve decisionMaker and evaluator.
"""
if newNodeKey ∉ keys(node.children) if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] = node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
@@ -270,30 +237,25 @@ end
""" Simulate interactions between agent and environment """ Simulate interactions between agent and environment
# Arguments # Arguments
- `a::T`
one of YiemAgent's agent
- `node::MCTSNode` - `node::MCTSNode`
node that will be a simulation starting point. node that will be a simulation starting point.
- `decisionMaker::Function` - `transition::Function`
function that receive state return Thought and Action A user function that handles how state transition.
- `transitionargs::NamedTuple`
Arguments for everything the user will use within transition().
- `maxdepth::Integer`
maximum depth level MCTS goes vertically.
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Return # Return
- `simTrajectoryReward::Number` - `(simTrajectoryReward, terminalstate)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}`
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3 maxdepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing