From b9458f5b05532aef371d9195a23bb09b9cd73205 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Thu, 20 Jun 2024 16:56:57 +0700 Subject: [PATCH] update --- src/mcts.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/mcts.jl b/src/mcts.jl index c37d9f2..49558c0 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -244,13 +244,14 @@ julia> # Signature """ -function expand(node::MCTSNode,transition::Function, args::NamedTuple; totalsample::Integer=3) +function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; + totalsample::Integer=3) nthSample = 0 while true nthSample += 1 if nthSample <= totalsample - newNodeKey, newstate, progressvalue = transition(node.state, args) + newNodeKey, newstate, progressvalue = transition(node.state, transitionargs) if newNodeKey ∉ keys(node.children) node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], @@ -286,7 +287,7 @@ julia> # Signature """ -function simulate(node::MCTSNode, transition::Function, args::NamedTuple; +function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; maxdepth::Integer=3, totalsample::Integer=3 )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} @@ -299,7 +300,7 @@ function simulate(node::MCTSNode, transition::Function, args::NamedTuple; terminalstate = node.state break else - expand(node, transition, args; + expand(node, transition, transitionargs; totalsample=totalsample) node = selectChildNode(node) end