140 lines
1.5 KiB
Julia
140 lines
1.5 KiB
Julia
module util

export UCTselect

using ..type

# ---------------------------------------------- 100 --------------------------------------------- #

"""
    UCTselect(node::MCTSNode, w::Real) -> MCTSNode

Select the child of `node` with the highest UCT score.

For a child that has been visited at least once, the score is

    child.statevalue + w * sqrt(log(node.visits) / child.visits)

This is an exploitation term (the child's state value) plus an exploration term. The
exploration term grows with the parent's visit count and shrinks as the child itself
is visited more. A child that has never been visited (`child.visits == 0`) would make
the exploration term divide by zero, so its score is `child.progressvalue` instead.

# Arguments
- `node::MCTSNode`: the node whose children are scored. `node.children` is iterated
  as `(state, childnode)` pairs.
- `w::Real`: exploration weight. Larger values favour less-visited children (the
  exploration term). Smaller values favour children with a high `statevalue` (the
  exploitation term). Values between 1 and 2 are typical. `w = sqrt(2)` reproduces
  the UCB1 bound.

# Returns
- `selectedNode::MCTSNode`: the child with the highest score. On ties, the child that
  appears first while iterating `node.children` wins (strict `>` comparison). If no
  score compares greater than `-Inf` (every score is `-Inf` or `NaN`), the first
  child is returned.

# Throws
- `ArgumentError` if `node` has no children.
"""
function UCTselect(node::MCTSNode, w::Real)::MCTSNode
    maxUCT = -Inf
    selectedNode = nothing
    # Remember the first child so a node whose scores are all -Inf/NaN still yields
    # a child, instead of failing to convert `nothing` to MCTSNode.
    firstChild = nothing

    for (_, childNode) in node.children
        firstChild === nothing && (firstChild = childNode)

        UCTvalue = if childNode.visits != 0
            exploreterm = w * sqrt(log(node.visits) / childNode.visits)
            childNode.statevalue + exploreterm
        else
            # childNode.visits == 0 would divide by zero in the explore term, so an
            # unvisited child is scored by its progress value instead.
            childNode.progressvalue
        end

        if UCTvalue > maxUCT
            maxUCT = UCTvalue
            selectedNode = childNode
        end
    end

    firstChild === nothing &&
        throw(ArgumentError("UCTselect: node has no children to select from"))

    return selectedNode === nothing ? firstChild : selectedNode
end

end # module util