module interface

export runMCTS

using Base.Threads, PrettyPrinting
using ..type, ..mcts, ..util

# ---------------------------------------------- 100 --------------------------------------------- #

# A state whose `:reward` is at least this value is recorded as a "high value" state.
const HIGH_VALUE_REWARD = 8

"""
    runMCTS(initialstate, transition, transitionargs; kwargs...)

Search the best action to take for a given state and task.

# Arguments
- `initialstate::T`
    initial state
- `transition::Function`
    a function that defines how the state transitions
- `transitionargs::NamedTuple`
    arguments for the transition function

# Keyword Arguments
- `horizontalSampleExpansionPhase::Integer`
    number of child states MCTS samples at each node during the expansion phase (default: 3)
- `horizontalSampleSimulationPhase::Integer`
    number of child states MCTS samples at each node during the simulation's expansion phase
    (default: 3)
- `maxSimulationDepth::Integer`
    number of levels MCTS goes down during the simulation phase (default: 3)
- `maxiterations::Integer`
    number of expansion -> simulation -> backpropagation cycles to run (default: 10)
- `explorationweight::Number`
    exploration weight controls how much MCTS should explore new states instead of exploiting
    a known state. 1.0 balances exploration and exploitation like 50%-50%.
    2.0 makes MCTS aggressively explore new states (default: 1.0)
- `earlystop::Union{Function,Nothing}`
    optional function called with the state of the node selected in the current iteration;
    the search stops as soon as it returns `true` (default: nothing)
- `saveSimulatedNode::Bool`
    whether to keep nodes created during the simulation phase (default: false)
- `multithread::Bool`
    whether to use multithreading during simulation (default: false)

# Returns
- `NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList)}`
    - root: the complete MCTS tree with root node
    - bestNextState: the best immediate next state
    - bestTerminalState: the best terminal state along the best trajectory
    - highValueStateList: copies of every state encountered with `state[:reward] >= 8`

# Example
    Refers to SQLLLM package
"""
function runMCTS(
    initialstate::T,
    transition::Function,
    transitionargs::NamedTuple;
    horizontalSampleExpansionPhase::Integer=3,
    horizontalSampleSimulationPhase::Integer=3,
    maxSimulationDepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
    earlystop::Union{Function,Nothing}=nothing,
    saveSimulatedNode::Bool=false,
    multithread=false,
)::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
    Tuple{MCTSNode,T,T,Vector{Dict{Symbol,Any}}}} where {T<:Any}

    root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing,
        Dict{String,MCTSNode}(), Dict{Symbol,Any}())

    # Storage for all high reward states. The channel is unbounded because nothing drains it
    # until the search is over; a bounded channel would block `put!` forever once full.
    highValueState = Channel{Any}(Inf)

    for _ in 1:maxiterations
        node = root
        node.visits += 1

        # selection: walk down the tree until a leaf is reached
        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            if node.state[:reward] >= HIGH_VALUE_REWARD
                put!(highValueState, deepcopy(node.state))
            end
            # MCTS arrived at a leaf node that is also a terminal state,
            # do nothing then go directly to backpropagation.
            # It means the end of this iteration
            backpropagate(node, node.reward)
        else
            expand(node, transition, transitionargs;
                horizontalSample=horizontalSampleExpansionPhase,
                multithread=multithread)

            if multithread
                # NOTE(review): every spawned task backpropagates into shared ancestors;
                # confirm `backpropagate` is safe to call concurrently.
                @sync for leafNode in values(node.children)
                    @spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
                        maxSimulationDepth=maxSimulationDepth,
                        horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
                        saveSimulatedNode=saveSimulatedNode,
                        multithread=multithread,
                        highValueState=highValueState)
                end
            else
                for leafNode in values(node.children)
                    simulateThenBackpropagate(leafNode, transition, transitionargs;
                        maxSimulationDepth=maxSimulationDepth,
                        horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
                        saveSimulatedNode=saveSimulatedNode,
                        multithread=multithread,
                        highValueState=highValueState)
                end
            end
        end

        # stop if the early stop condition is met
        if earlystop !== nothing && earlystop(node.state)
            break
        end
    end

    # select the best next state and the best terminal state along the best trajectory
    bestNextState = selectBestNextNode(root)
    bestTerminalState = selectBestTrajectoryNode(root)

    # take all high value states from the channel and put them in a list
    highValueStateList = Vector{Dict{Symbol,Any}}()
    while !isempty(highValueState)
        push!(highValueStateList, take!(highValueState))
    end

    result = (
        root=root,
        bestNextState=bestNextState.state,
        bestTerminalState=bestTerminalState.state,
        highValueStateList=highValueStateList,
    )

    return result
end


"""
    simulateThenBackpropagate(node, transition, transitionargs; kwargs...)

Run one simulation starting at `node`, then backpropagate the resulting trajectory reward.

# Arguments
- `node::MCTSNode`
    current node to simulate from
- `transition::Function`
    a function that defines how the state transitions
- `transitionargs::NamedTuple`
    arguments for the transition function

# Keyword Arguments
- `maxSimulationDepth::Integer`
    number of levels MCTS goes down during the simulation phase (default: 3)
- `horizontalSampleSimulationPhase::Integer`
    number of child states MCTS samples at each node during the simulation phase (default: 3)
- `saveSimulatedNode::Bool`
    whether to keep nodes created during the simulation phase (default: false)
- `multithread::Bool`
    whether to use multithreading during simulation (default: false)
- `highValueState`
    optional sink supporting `put!` (e.g. a `Channel`) that receives a copy of the terminal
    state when its `:reward` is >= 8; `nothing` disables collection (default: nothing)

# Returns
    Nothing, but updates the node's reward and visit count through backpropagation
"""
function simulateThenBackpropagate(
    node::MCTSNode,
    transition::Function,
    transitionargs::NamedTuple;
    maxSimulationDepth::Integer=3,
    horizontalSampleSimulationPhase::Integer=3,
    saveSimulatedNode::Bool=false,
    multithread=false,
    highValueState=nothing,
)
    simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
        maxSimulationDepth=maxSimulationDepth,
        horizontalSample=horizontalSampleSimulationPhase,
        multithread=multithread)

    # if the terminal state has a high enough reward, store a copy of it in highValueState
    if highValueState !== nothing &&
       terminalstate !== nothing &&
       terminalstate[:reward] >= HIGH_VALUE_REWARD
        put!(highValueState, deepcopy(terminalstate))
    end

    backpropagate(node, simTrajectoryReward)

    # drop the simulated subtree unless the user wants to keep it
    if !saveSimulatedNode
        node.children = Dict{String,MCTSNode}()
    end

    return nothing
end


end # module interface