This commit is contained in:
narawat lamaiin
2024-05-06 16:13:01 +07:00
parent 89c6af780f
commit ebacc16922
4 changed files with 28 additions and 10 deletions

View File

@@ -348,8 +348,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode
# loop through the node's children dictionary to find the highest progress value
for (k, childNode) in node.children
thisNodeProgressValue = childNode.statevalue + childNode.reward
if childNode.statevalue > highestProgressValue
highestProgressValue = childNode.statevalue + childNode.reward
highestProgressValue = thisNodeProgressValue
nodekey = childNode.nodekey
end
end
@@ -443,10 +444,10 @@ function runMCTS(
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
leaf_node = UCTselect(node, w)
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
isterminal, maxDepth, n=n)
backpropagate(leaf_node, simTrajectoryReward)
leafNode = UCTselect(node, w)
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
isterminal, maxDepth, n=n)
backpropagate(leafNode, simTrajectoryReward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])