This commit is contained in:
2025-03-15 08:28:13 +07:00
parent 2eff443f70
commit b2c53ffa45
4 changed files with 94 additions and 52 deletions

View File

@@ -26,27 +26,23 @@ using ..type, ..mcts, ..util
- `horizontalSampleSimulationPhase::Integer` - `horizontalSampleSimulationPhase::Integer`
a number of child state MCTS sample at each node during simulation's expansion phase (default: 3) a number of child state MCTS sample at each node during simulation's expansion phase (default: 3)
- `maxSimulationDepth::Integer` - `maxSimulationDepth::Integer`
a number of levels MCTS goes during simulation phase (default: 3) a number of levels MCTS goes during simulation phase (default: 3)
- `maxiterations::Integer` - `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle (default: 10) a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle (default: 10)
- `explorationweight::Number` - `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state (default: 1.0) aggressively explore new state (default: 1.0)
- `earlystop::Union{Function,Nothing}` - `earlystop::Union{Function,Nothing}`
optional function to check early stopping condition (default: nothing) optional function to check early stopping condition (default: nothing)
- `saveSimulatedNode::Bool` - `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false) whether to save nodes created during simulation phase (default: false)
- `multithread::Bool`
whether to use multithreading during simulation (default: false)
# Returns # Returns
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}` - `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
- mctstree: the complete MCTS tree with root node - root: the complete MCTS tree with root node
- bestNextState: the best immediate next state - bestNextState: the best immediate next state
- bestFinalState: the best final state along the best trajectory - bestFinalState: the best final state along the best trajectory
@@ -67,8 +63,8 @@ function runMCTS(
explorationweight::Number=1.0, explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing, earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
multithread=false) where {T<:Any} multithread=false
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} )::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}()) Dict{Symbol,Any}())
@@ -121,7 +117,29 @@ function runMCTS(
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end end
""" Search the best action to take for a given state and task
# Arguments
- `node::MCTSNode`
current node to simulate from
- `transition::Function`
a function that defines how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `maxSimulationDepth::Integer`
a number of levels MCTS goes during simulation phase (default: 3)
- `horizontalSampleSimulationPhase::Integer`
a number of child states MCTS samples at each node during simulation phase (default: 3)
- `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false)
- `multithread::Bool`
whether to use multithreading during simulation (default: false)
# Returns
Nothing, but updates the node's reward and visit count through backpropagation
"""
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
@@ -209,9 +227,6 @@ end

View File

@@ -10,27 +10,29 @@ using ..type
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" """ Select the best next node based on the highest value metric
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree to evaluate
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the child node with highest value based on either:
- statevalue/visits ratio if any nodes have non-zero statevalue
# Signature - progressvalue + reward otherwise
""" """
function selectBestNextNode(node::MCTSNode)::MCTSNode function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1 highestProgressValue = -1
nodekey = nothing nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node # Calculate sum of statevalues across all child nodes
stateValueSum = sum([v.statevalue for (k, v) in node.children]) stateValueSum = sum([v.statevalue for (k, v) in node.children])
# If any nodes have non-zero statevalue, use statevalue/visits as selection metric
if stateValueSum != 0 if stateValueSum != 0
for (k, childnode) in node.children for (k, childnode) in node.children
# Calculate average statevalue per visit
potential = childnode.statevalue / childnode.visits potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue if potential > highestProgressValue
@@ -39,6 +41,7 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
end end
end end
else else
# Otherwise use progressvalue + reward as selection metric
for (k, childnode) in node.children for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward potential = childnode.progressvalue + childnode.reward
@@ -53,15 +56,16 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
end end
""" """ Select the best trajectory node based on the highest reward
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
node of a search tree node of a search tree to evaluate
# Return # Return
- `childNode::MCTSNode` - `childNode::MCTSNode`
the highest value child node the highest value child node found by traversing down the tree using selectBestNextNode
until reaching a leaf node
# Signature # Signature
""" """
@@ -86,7 +90,8 @@ end
reward it is now. reward it is now.
# Return # Return
- `None` - `Nothing`
This function modifies the nodes in place and returns nothing
# Signature # Signature
""" """
@@ -94,22 +99,23 @@ function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node) while !isroot(node)
# Update the statistics of the current node based on the result of the playout # Update the statistics of the current node based on the result of the playout
node.visits += 1 node.visits += 1 # Increment visit count for this node
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits # Update running average of state value
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent node = node.parent # Move up to parent node for next iteration
end end
end end
""" Determine whether a node is a leaf node of a search tree. """ Determine whether a node is a leaf node of a search tree.
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
a search tree node a search tree node
# Return # Return
- `result::Bool` - `result::Bool`
true if it is a leaf node, false otherwise. true if it is a leaf node (has no children), false otherwise.
# Example # Example
```jldoctest ```jldoctest
julia> using Revise julia> using Revise
@@ -128,14 +134,10 @@ julia> YiemAgent.isleaf(root)
true true
``` ```
# TODO
[] update docs
# Signature # Signature
""" """
isleaf(node::MCTSNode)::Bool = isempty(node.children) isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node """ Determine wheter a given node is a root node
# Arguments # Arguments
@@ -185,13 +187,17 @@ end
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
MCTS node MCTS node to expand
- `transition::Function` - `transition::Function`
A function that handles state transition. A function that handles state transition.
- `transitionargs::NamedTuple` - `transitionargs::NamedTuple`
Arguments for transition() Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally) # Keyword Arguments
- `horizontalSample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally). Defaults to 3.
- `multithread::Bool`
Whether to run expansion in parallel using multiple threads. Defaults to false.
# Return # Return
- None - None
@@ -211,6 +217,21 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
end end
end end
""" Helper function to expand a single child node.
# Arguments
- `node::MCTSNode`
Parent MCTS node to expand from
- `transition::Function`
A function that handles state transition
- `transitionargs::NamedTuple`
Arguments for transition()
# Return
- None
# Signature
"""
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
result = transition(node.state, transitionargs) result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
@@ -231,7 +252,6 @@ function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple
end end
end end
""" Simulate interactions between agent and environment """ Simulate interactions between agent and environment
# Arguments # Arguments

View File

@@ -10,17 +10,27 @@ using GeneralUtils
""" a node for MCTS search tree """ a node for MCTS search tree
# Arguments # Arguments
- `state::T` - `nodekey::AbstractString`
a state of a game. Can be a Dict or something else. unique identifier for the node
- `visits::Integer ` - `state::AbstractDict`
number of time the game visits this state a state of a game represented as a dictionary
- `stateValue::Float64` - `visits::Integer`
state value number of times the game visits this state
- `children::Dict{T, MCTSNode}` - `progressvalue::Number`
children node estimated value by LLM's reasoning
- `statevalue::Number`
current state value, stores node's immediate reward and future discounted rewards
- `reward::Number`
immediate reward for this node
- `isterminal::Bool`
whether this node represents a terminal state
- `parent::Union{MCTSNode, Nothing}`
reference to parent node, Nothing for root
- `children::Dict{String, MCTSNode}`
mapping of child nodes
- `etc::Dict{Symbol, Any}`
additional storage for arbitrary data
# Return
- `nothing`
# Example # Example
```jldoctest ```jldoctest
julia> state = Dict( julia> state = Dict(
@@ -36,9 +46,6 @@ julia> state = Dict(
) )
``` ```
# TODO
[] update docstring
# Signature # Signature
""" """
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
@@ -113,7 +120,6 @@ end
end # module type end # module type

View File

@@ -17,6 +17,8 @@ using ..type
Value 2.0 makes MCTS aggressively search the tree. Value 2.0 makes MCTS aggressively search the tree.
# Return # Return
- `selectedNode::MCTSNode` - `selectedNode::MCTSNode`
child node with highest UCT score. UCT score balances between exploitation (state value)
and exploration (visit count) based on the exploration weight w.
# Example # Example
```jldoctest ```jldoctest
@@ -133,7 +135,6 @@ end
end # module util end # module util