diff --git a/src/interface.jl b/src/interface.jl index 5b4e88e..e142230 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -47,17 +47,17 @@ julia> # Signature """ function runMCTS( - initialState, + initialstate, transition::Function, - args..., + transitionargs::NamedTuple, ; totalsample::Integer=3, - maxDepth::Integer=3, + maxdepth::Integer=3, maxiterations::Integer=10, explorationweight::Number=1.0, ) - root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) + root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) for nth in 1:maxiterations node = root @@ -71,11 +71,11 @@ function runMCTS( # do nothing then go directly to backpropagation backpropagate(leafNode, node.reward) else - expand(node, transition, args...; + expand(node, transition, transitionargs; totalsample=totalsample) leafNode = selectChildNode(node) - simTrajectoryReward, terminalstate = simulate(leafNode, transition, args...; - maxDepth=maxDepth, totalsample=totalsample) + simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; + maxdepth=maxdepth, totalsample=totalsample) if terminalstate !== nothing #XXX not sure why I need this terminalstate[:totalTrajectoryReward] = simTrajectoryReward end diff --git a/src/mcts.jl b/src/mcts.jl index 47f7d34..1896444 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -244,13 +244,13 @@ julia> # Signature """ -function expand(node::MCTSNode,transition::Function, args...; totalsample::Integer=3) +function expand(node::MCTSNode,transition::Function, args::NamedTuple; totalsample::Integer=3) nthSample = 0 while true nthSample += 1 if nthSample <= totalsample - newNodeKey, newstate, progressvalue = transition(node.state, args...) + newNodeKey, newstate, progressvalue = transition(node.state, args) if newNodeKey ∉ keys(node.children) node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], @@ -287,19 +287,19 @@ julia> # Signature """ function simulate(node::MCTSNode, transition::Function, args...; - maxDepth::Integer=3, totalsample::Integer=3 + maxdepth::Integer=3, totalsample::Integer=3 )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} simTrajectoryReward = 0.0 terminalstate = nothing - for depth in 1:maxDepth + for depth in 1:maxdepth simTrajectoryReward += node.reward if node.isterminal terminalstate = node.state break else - expand(node, transition, args...; + expand(node, transition, args; totalsample=totalsample) node = selectChildNode(node) end