update
This commit is contained in:
16
src/mcts.jl
16
src/mcts.jl
@@ -182,6 +182,7 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
|
||||
|
||||
for depth in 1:maxDepth
|
||||
if node.isterminal
|
||||
simTrajectoryReward += node.reward
|
||||
break
|
||||
else
|
||||
simTrajectoryReward += node.reward
|
||||
@@ -448,12 +449,15 @@ function runMCTS(
|
||||
while !isleaf(node)
|
||||
node = UCTselect(node, w)
|
||||
end
|
||||
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
leafNode = UCTselect(node, w)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
if node.isterminal
|
||||
# MCTS arrive at the leaf node that is also a terminal state,
|
||||
# do nothing then go directly to backpropagation
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
leafNode = UCTselect(node, w)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
end
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user