This commit is contained in:
narawat lamaiin
2024-06-20 16:56:57 +07:00
parent 59a2b009c6
commit b9458f5b05

View File

@@ -244,13 +244,14 @@ julia>
# Signature # Signature
""" """
function expand(node::MCTSNode,transition::Function, args::NamedTuple; totalsample::Integer=3) function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3)
nthSample = 0 nthSample = 0
while true while true
nthSample += 1 nthSample += 1
if nthSample <= totalsample if nthSample <= totalsample
newNodeKey, newstate, progressvalue = transition(node.state, args) newNodeKey, newstate, progressvalue = transition(node.state, transitionargs)
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
@@ -286,7 +287,7 @@ julia>
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, args::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3 maxdepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
@@ -299,7 +300,7 @@ function simulate(node::MCTSNode, transition::Function, args::NamedTuple;
terminalstate = node.state terminalstate = node.state
break break
else else
expand(node, transition, args; expand(node, transition, transitionargs;
totalsample=totalsample) totalsample=totalsample)
node = selectChildNode(node) node = selectChildNode(node)
end end