This commit is contained in:
2025-03-15 08:28:13 +07:00
parent 2eff443f70
commit b2c53ffa45
4 changed files with 94 additions and 52 deletions

View File

@@ -26,27 +26,23 @@ using ..type, ..mcts, ..util
- `horizontalSampleSimulationPhase::Integer`
a number of child state MCTS sample at each node during simulation's expansion phase (default: 3)
- `maxSimulationDepth::Integer`
a number of levels MCTS goes during simulation phase (default: 3)
- `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle (default: 10)
- `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state (default: 1.0)
- `earlystop::Union{Function,Nothing}`
optional function to check early stopping condition (default: nothing)
- `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false)
- `multithread::Bool`
whether to use multithreading during simulation (default: false)
# Returns
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
- mctstree: the complete MCTS tree with root node
- `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
- root: the complete MCTS tree with root node
- bestNextState: the best immediate next state
- bestFinalState: the best final state along the best trajectory
@@ -67,8 +63,8 @@ function runMCTS(
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false,
multithread=false) where {T<:Any}
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
multithread=false
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}())
@@ -121,7 +117,29 @@ function runMCTS(
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
""" Search the best action to take for a given state and task
# Arguments
- `node::MCTSNode`
current node to simulate from
- `transition::Function`
a function that defines how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `maxSimulationDepth::Integer`
a number of levels MCTS goes during simulation phase (default: 3)
- `horizontalSampleSimulationPhase::Integer`
a number of child states MCTS samples at each node during simulation phase (default: 3)
- `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false)
- `multithread::Bool`
whether to use multithreading during simulation (default: false)
# Returns
Nothing, but updates the node's reward and visit count through backpropagation
"""
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false,
@@ -209,9 +227,6 @@ end

View File

@@ -10,27 +10,29 @@ using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
""" Select the best next node based on the highest value metric
# Arguments
- `node::MCTSNode`
node of a search tree
node of a search tree to evaluate
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
the child node with highest value based on either:
- statevalue/visits ratio if any nodes have non-zero statevalue
- progressvalue + reward otherwise
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1
nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
# Calculate sum of statevalues across all child nodes
stateValueSum = sum([v.statevalue for (k, v) in node.children])
# If any nodes have non-zero statevalue, use statevalue/visits as selection metric
if stateValueSum != 0
for (k, childnode) in node.children
# Calculate average statevalue per visit
potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue
@@ -39,6 +41,7 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
end
end
else
# Otherwise use progressvalue + reward as selection metric
for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward
@@ -53,15 +56,16 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
end
"""
""" Select the best trajectory node based on the highest reward
# Arguments
- `node::MCTSNode`
node of a search tree
node of a search tree to evaluate
# Return
- `childNode::MCTSNode`
the highest value child node
the highest value child node found by traversing down the tree using selectBestNextNode
until reaching a leaf node
# Signature
"""
@@ -86,7 +90,8 @@ end
reward it is now.
# Return
- `None`
- `Nothing`
This function modifies the nodes in place and returns nothing
# Signature
"""
@@ -94,22 +99,23 @@ function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
node.visits += 1 # Increment visit count for this node
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits # Update running average of state value
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent
node = node.parent # Move up to parent node for next iteration
end
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
true if it is a leaf node (has no children), false otherwise.
# Example
```jldoctest
julia> using Revise
@@ -128,14 +134,10 @@ julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node
# Arguments
@@ -185,13 +187,17 @@ end
# Arguments
- `node::MCTSNode`
MCTS node
MCTS node to expand
- `transition::Function`
A function that handles state transition.
- `transitionargs::NamedTuple`
Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Keyword Arguments
- `horizontalSample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally). Defaults to 3.
- `multithread::Bool`
Whether to run expansion in parallel using multiple threads. Defaults to false.
# Return
- None
@@ -211,6 +217,21 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
end
end
""" Helper function to expand a single child node.
# Arguments
- `node::MCTSNode`
Parent MCTS node to expand from
- `transition::Function`
A function that handles state transition
- `transitionargs::NamedTuple`
Arguments for transition()
# Return
- None
# Signature
"""
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey]
@@ -231,7 +252,6 @@ function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple
end
end
""" Simulate interactions between agent and environment
# Arguments

View File

@@ -10,17 +10,27 @@ using GeneralUtils
""" a node for MCTS search tree
# Arguments
- `state::T`
a state of a game. Can be a Dict or something else.
- `visits::Integer `
number of time the game visits this state
- `stateValue::Float64`
state value
- `children::Dict{T, MCTSNode}`
children node
- `nodekey::AbstractString`
unique identifier for the node
- `state::AbstractDict`
a state of a game represented as a dictionary
- `visits::Integer`
number of times the game visits this state
- `progressvalue::Number`
estimated value by LLM's reasoning
- `statevalue::Number`
current state value, stores node's immediate reward and future discounted rewards
- `reward::Number`
immediate reward for this node
- `isterminal::Bool`
whether this node represents a terminal state
- `parent::Union{MCTSNode, Nothing}`
reference to parent node, Nothing for root
- `children::Dict{String, MCTSNode}`
mapping of child nodes
- `etc::Dict{Symbol, Any}`
additional storage for arbitrary data
# Return
- `nothing`
# Example
```jldoctest
julia> state = Dict(
@@ -36,9 +46,6 @@ julia> state = Dict(
)
```
# TODO
[] update docstring
# Signature
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
@@ -113,7 +120,6 @@ end
end # module type

View File

@@ -17,6 +17,8 @@ using ..type
Value 2.0 makes MCTS aggressively search the tree.
# Return
- `selectedNode::MCTSNode`
child node with highest UCT score. UCT score balances between exploitation (state value)
and exploration (visit count) based on the exploration weight w.
# Example
```jldoctest
@@ -133,7 +135,6 @@ end
end # module util