diff --git a/src/mcts.jl b/src/mcts.jl index 9548e71..95e698e 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -182,6 +182,7 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim for depth in 1:maxDepth if node.isterminal + simTrajectoryReward += node.reward break else simTrajectoryReward += node.reward @@ -448,12 +449,15 @@ function runMCTS( while !isleaf(node) node = UCTselect(node, w) end - - expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) - - leafNode = UCTselect(node, w) - simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, - isterminal, maxDepth, n=n) + if node.isterminal + # The selected leaf is itself a terminal state: skip expansion and simulation + # and go directly to backpropagation. NOTE(review): leafNode and simTrajectoryReward are not assigned on this path; confirm they are defined before backpropagate is called + else + expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) + leafNode = UCTselect(node, w) + simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, + isterminal, maxDepth, n=n) + end backpropagate(leafNode, simTrajectoryReward) end