update
This commit is contained in:
12
src/mcts.jl
12
src/mcts.jl
@@ -407,6 +407,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
# loop thought node children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
potential = childNode.progressvalue + childNode.reward
|
||||
if childNode.reward > 0 #XXX for testing. remove when done.
|
||||
println("")
|
||||
end
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childNode.nodekey
|
||||
@@ -485,7 +488,8 @@ function runMCTS(
|
||||
n::Integer,
|
||||
maxDepth::Integer,
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
w::Float64
|
||||
) where {T1<:agent}
|
||||
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
|
||||
@@ -510,7 +514,11 @@ function runMCTS(
|
||||
end
|
||||
|
||||
best_child_state = argmax([child.statevalue / child.visits for child in values(root.children)])
|
||||
error("---> runMCTS")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return best_child_state
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user