diff --git a/src/mcts.jl b/src/mcts.jl index 37efc01..d3edab1 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -513,13 +513,17 @@ function runMCTS( end end - best_child_state = argmax([child.statevalue / child.visits for child in values(root.children)]) + avgStateValue = 0 + selectedChildKey = nothing + for (k, v) in root.children + k_avgStateValue = v.statevalue / v.visits + if k_avgStateValue > avgStateValue + avgStateValue = k_avgStateValue + selectedChildKey = k + end + end - - - - - return best_child_state + return root.children[selectedChildKey] end