From b2c53ffa450535c6d13aefbc36fff8d33a5d06a8 Mon Sep 17 00:00:00 2001 From: tonaerospace Date: Sat, 15 Mar 2025 08:28:13 +0700 Subject: [PATCH] update --- src/interface.jl | 41 ++++++++++++++++++++--------- src/mcts.jl | 68 +++++++++++++++++++++++++++++++----------------- src/type.jl | 34 ++++++++++++++---------- src/util.jl | 3 ++- 4 files changed, 94 insertions(+), 52 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d5974b7..32a8b3f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -26,27 +26,23 @@ using ..type, ..mcts, ..util - `horizontalSampleSimulationPhase::Integer` a number of child state MCTS sample at each node during simulation's expansion phase (default: 3) - `maxSimulationDepth::Integer` - a number of levels MCTS goes during simulation phase (default: 3) - `maxiterations::Integer` - a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle (default: 10) - `explorationweight::Number` exploration weight controls how much MCTS should explore new state instead of exploit a known state. 1.0 balance between exploration and exploitation like 50%-50%. 
2.0 makes MCTS - aggressively explore new state (default: 1.0) - `earlystop::Union{Function,Nothing}` optional function to check early stopping condition (default: nothing) - `saveSimulatedNode::Bool` whether to save nodes created during simulation phase (default: false) - - - + - `multithread::Bool` + whether to use multithreading during simulation (default: false) # Returns - - `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}` - - mctstree: the complete MCTS tree with root node + - `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}` + - root: the complete MCTS tree with root node - bestNextState: the best immediate next state - bestFinalState: the best final state along the best trajectory @@ -67,8 +63,8 @@ function runMCTS( explorationweight::Number=1.0, earlystop::Union{Function,Nothing}=nothing, saveSimulatedNode::Bool=false, - multithread=false) where {T<:Any} -# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} + multithread=false + )::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any} root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), Dict{Symbol,Any}()) @@ -121,7 +117,29 @@ function runMCTS( return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) end +""" Simulate from a node, then backpropagate the simulation result up the tree +# Arguments +- `node::MCTSNode` + current node to simulate from +- `transition::Function` + a function that defines how the state transitions +- `transitionargs::NamedTuple` + arguments for transition function + +# Keyword Arguments +- `maxSimulationDepth::Integer` + a number of levels MCTS goes during simulation phase (default: 3) +- `horizontalSampleSimulationPhase::Integer` + a number of child states MCTS samples at each node during simulation phase (default: 3) +- `saveSimulatedNode::Bool` + whether to save nodes created during simulation phase 
(default: false) +- `multithread::Bool` + whether to use multithreading during simulation (default: false) + +# Returns + Nothing, but updates the node's reward and visit count through backpropagation +""" function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, saveSimulatedNode::Bool=false, @@ -209,9 +227,6 @@ end - - - diff --git a/src/mcts.jl b/src/mcts.jl index 81839ab..2715baa 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -10,27 +10,29 @@ using ..type # ---------------------------------------------- 100 --------------------------------------------- # -""" +""" Select the best next node based on the highest value metric # Arguments - `node::MCTSNode` - node of a search tree + node of a search tree to evaluate # Return - `childNode::MCTSNode` - the highest value child node - -# Signature + the child node with highest value based on either: + - statevalue/visits ratio if any nodes have non-zero statevalue + - progressvalue + reward otherwise """ function selectBestNextNode(node::MCTSNode)::MCTSNode highestProgressValue = -1 nodekey = nothing - # if all childnode has statevalue == 0, use progressvalue + reward to select the best node + # Calculate sum of statevalues across all child nodes stateValueSum = sum([v.statevalue for (k, v) in node.children]) + # If any nodes have non-zero statevalue, use statevalue/visits as selection metric if stateValueSum != 0 for (k, childnode) in node.children + # Calculate average statevalue per visit potential = childnode.statevalue / childnode.visits if potential > highestProgressValue @@ -39,6 +41,7 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode end end else + # Otherwise use progressvalue + reward as selection metric for (k, childnode) in node.children potential = childnode.progressvalue + childnode.reward @@ -53,15 +56,16 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode end -""" +""" Select 
the best trajectory node based on the highest reward # Arguments - `node::MCTSNode` - node of a search tree + node of a search tree to evaluate # Return - `childNode::MCTSNode` - the highest value child node + the highest value child node found by traversing down the tree using selectBestNextNode + until reaching a leaf node # Signature """ @@ -86,7 +90,8 @@ end reward it is now. # Return - - `None` + - `Nothing` + This function modifies the nodes in place and returns nothing # Signature """ @@ -94,22 +99,23 @@ function backpropagate(node::MCTSNode, simTrajectoryReward::T; discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} while !isroot(node) # Update the statistics of the current node based on the result of the playout - node.visits += 1 - node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits + node.visits += 1 # Increment visit count for this node + node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits # NOTE(review): `+=` adds the recomputed mean onto the existing statevalue instead of replacing it, so this is not a plain running average; confirm whether `=` was intended simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain - node = node.parent + node = node.parent # Move up to parent node for next iteration end end - """ Determine whether a node is a leaf node of a search tree. # Arguments - `node::MCTSNode` a search tree node + # Return - `result::Bool` - true if it is a leaf node, false otherwise. + true if it is a leaf node (has no children), false otherwise. + # Example ```jldoctest julia> using Revise @@ -128,14 +134,10 @@ julia> YiemAgent.isleaf(root) true ``` -# TODO - [] update docs - # Signature """ isleaf(node::MCTSNode)::Bool = isempty(node.children) - """ Determine wheter a given node is a root node # Arguments @@ -185,14 +187,18 @@ end # Arguments - `node::MCTSNode` - MCTS node + MCTS node to expand - `transition::Function` A function that handles state transition. 
- `transitionargs::NamedTuple` Arguments for transition() - - `totalsample::Integer` - Total number to sample from the current node (i.e. expand new node horizontally) - + +# Keyword Arguments + - `horizontalSample::Integer` + Total number to sample from the current node (i.e. expand new node horizontally). Defaults to 3. + - `multithread::Bool` + Whether to run expansion in parallel using multiple threads. Defaults to false. + # Return - None @@ -211,6 +217,21 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; end end +""" Helper function to expand a single child node. + +# Arguments + - `node::MCTSNode` + Parent MCTS node to expand from + - `transition::Function` + A function that handles state transition + - `transitionargs::NamedTuple` + Arguments for transition() + +# Return + - None + +# Signature +""" function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) result = transition(node.state, transitionargs) newNodeKey::AbstractString = result[:newNodeKey] @@ -231,7 +252,6 @@ function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple end end - """ Simulate interactions between agent and environment # Arguments diff --git a/src/type.jl b/src/type.jl index cc1d7dc..0b8f46e 100644 --- a/src/type.jl +++ b/src/type.jl @@ -10,17 +10,27 @@ using GeneralUtils """ a node for MCTS search tree # Arguments - - `state::T` - a state of a game. Can be a Dict or something else. 
- - `visits::Integer ` - number of time the game visits this state - - `stateValue::Float64` - state value - - `children::Dict{T, MCTSNode}` - children node + - `nodekey::AbstractString` + unique identifier for the node + - `state::AbstractDict` + a state of a game represented as a dictionary + - `visits::Integer` + number of times the game visits this state + - `progressvalue::Number` + estimated value by LLM's reasoning + - `statevalue::Number` + current state value, stores node's immediate reward and future discounted rewards + - `reward::Number` + immediate reward for this node + - `isterminal::Bool` + whether this node represents a terminal state + - `parent::Union{MCTSNode, Nothing}` + reference to parent node, Nothing for root + - `children::Dict{String, MCTSNode}` + mapping of child nodes + - `etc::Dict{Symbol, Any}` + additional storage for arbitrary data -# Return - - `nothing` # Example ```jldoctest julia> state = Dict( @@ -36,9 +46,6 @@ julia> state = Dict( ) ``` -# TODO - [] update docstring - # Signature """ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} @@ -113,7 +120,6 @@ end - end # module type \ No newline at end of file diff --git a/src/util.jl b/src/util.jl index fcdd9d5..512556d 100644 --- a/src/util.jl +++ b/src/util.jl @@ -17,6 +17,8 @@ using ..type Value 2.0 makes MCTS aggressively search the tree. # Return - `selectedNode::MCTSNode` + child node with highest UCT score. UCT score balances between exploitation (state value) + and exploration (visit count) based on the exploration weight w. # Example ```jldoctest @@ -133,7 +135,6 @@ end - end # module util \ No newline at end of file