update
This commit is contained in:
11
src/mcts.jl
11
src/mcts.jl
@@ -348,8 +348,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
|
||||
# loop thought node children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
thisNodeProgressValue = childNode.statevalue + childNode.reward
|
||||
if childNode.statevalue > highestProgressValue
|
||||
highestProgressValue = childNode.statevalue + childNode.reward
|
||||
highestProgressValue = thisNodeProgressValue
|
||||
nodekey = childNode.nodekey
|
||||
end
|
||||
end
|
||||
@@ -443,10 +444,10 @@ function runMCTS(
|
||||
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
leaf_node = UCTselect(node, w)
|
||||
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
backpropagate(leaf_node, simTrajectoryReward)
|
||||
leafNode = UCTselect(node, w)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
|
||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||
|
||||
Reference in New Issue
Block a user