Files
LLMMCTS/src/interface.jl
narawat lamaiin 093290a33b update
2025-03-22 21:33:20 +07:00

267 lines
7.0 KiB
Julia

module interface

export runMCTS

using Base.Threads, PrettyPrinting
using ..type, ..mcts, ..util

# ---------------------------------------------- 100 --------------------------------------------- #

"""
    runMCTS(initialstate, transition, transitionargs; kwargs...)

Search the best action to take for a given state and task using Monte Carlo Tree Search.

# Arguments
- `initialstate::T`
    initial state
- `transition::Function`
    a function that defines how the state transitions
- `transitionargs::NamedTuple`
    arguments for the transition function

# Keyword Arguments
- `horizontalSampleExpansionPhase::Integer`
    number of child states MCTS samples at each node during the expansion phase
    (default: 3)
- `horizontalSampleSimulationPhase::Integer`
    number of child states MCTS samples at each node during the simulation's expansion
    phase (default: 3)
- `maxSimulationDepth::Integer`
    number of levels MCTS descends during the simulation phase (default: 3)
- `maxiterations::Integer`
    number of selection -> expansion -> simulation -> backpropagation cycles (default: 10)
- `explorationweight::Number`
    exploration weight passed to `UCTselect`; controls how much MCTS explores new states
    instead of exploiting known ones. 1.0 balances exploration and exploitation, 2.0 makes
    MCTS explore aggressively (default: 1.0)
- `earlystop::Union{Function,Nothing}`
    optional predicate called after each iteration with the state of the leaf selected in
    that iteration; the search stops as soon as it returns `true` (default: nothing)
- `saveSimulatedNode::Bool`
    whether to keep the nodes created during the simulation phase (default: false)
- `multithread`
    whether to simulate sibling leaf nodes as parallel tasks; also forwarded to `expand`
    and `simulate` (default: false)
- `highValueThreshold::Real`
    terminal states whose `:reward` is `>=` this value are collected into
    `highValueStateList` (default: 8)

# Returns
- `NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList)}`
    - `root`: the complete MCTS tree, rooted at a node holding `initialstate`
    - `bestNextState`: the state of the node chosen by `selectBestNextNode(root)`
    - `bestTerminalState`: the state of the node chosen by `selectBestTrajectoryNode(root)`
    - `highValueStateList`: deep copies of the terminal states reached whose `:reward` was
      `>= highValueThreshold` (may contain duplicates if a terminal leaf is revisited)

# Example
Refer to the SQLLLM package.
"""
function runMCTS(
    initialstate::T,
    transition::Function,
    transitionargs::NamedTuple,
    ;
    horizontalSampleExpansionPhase::Integer=3,
    horizontalSampleSimulationPhase::Integer=3,
    maxSimulationDepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
    earlystop::Union{Function,Nothing}=nothing,
    saveSimulatedNode::Bool=false,
    multithread=false,
    highValueThreshold::Real=8,
)::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
              Tuple{MCTSNode,T,T,Vector{Dict{Symbol,Any}}}} where {T<:Any}
    root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
                    Dict{Symbol,Any}())

    # Storage for all high-reward terminal states. The buffer is unbounded because nothing
    # is taken from the channel until the search finishes; a fixed capacity would make
    # `put!` block forever once it filled up.
    highValueState = Channel{Any}(Inf)

    # Parallel simulations of sibling leaves all backpropagate through the same ancestors;
    # this lock serializes those updates so the tree statistics are not raced.
    backpropLock = ReentrantLock()

    for _ in 1:maxiterations
        node = root
        node.visits += 1

        # selection: descend with UCT until a leaf is reached
        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            # The leaf is also a terminal state: nothing to expand or simulate. Record it if
            # it is high value, then go directly to backpropagation; this ends the iteration.
            if node.state[:reward] >= highValueThreshold
                put!(highValueState, deepcopy(node.state))
            end
            backpropagate(node, node.reward)
        else
            expand(node, transition, transitionargs;
                   horizontalSample=horizontalSampleExpansionPhase,
                   multithread=multithread)

            # keyword arguments shared by every simulation of this iteration
            simkwargs = (
                maxSimulationDepth=maxSimulationDepth,
                horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
                saveSimulatedNode=saveSimulatedNode,
                multithread=multithread,
                highValueState=highValueState,
                highValueThreshold=highValueThreshold,
            )

            if multithread
                # `@sync` waits for every spawned simulation before the next iteration
                @sync for leafNode in values(node.children)
                    @spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
                                                     simkwargs...,
                                                     backpropLock=backpropLock)
                end
            else
                for leafNode in values(node.children)
                    simulateThenBackpropagate(leafNode, transition, transitionargs;
                                              simkwargs...)
                end
            end
        end

        # stop if the early stop condition is met
        if earlystop !== nothing && earlystop(node.state)
            break
        end
    end

    # select the best next state and the best terminal state along the best trajectory
    bestNextState = selectBestNextNode(root)
    bestTerminalState = selectBestTrajectoryNode(root)

    # All producers have finished (serial calls returned, spawned tasks were awaited by
    # `@sync`), so draining the channel here collects every recorded state.
    highValueStateList = Vector{Dict{Symbol,Any}}()
    while !isempty(highValueState)
        push!(highValueStateList, take!(highValueState))
    end

    return (
        root=root,
        bestNextState=bestNextState.state,
        bestTerminalState=bestTerminalState.state,
        highValueStateList=highValueStateList,
    )
end

"""
    simulateThenBackpropagate(node, transition, transitionargs; kwargs...)

Run a simulation from `node`, record its terminal state if it is high value, then
backpropagate the simulated trajectory reward starting at `node`.

# Arguments
- `node::MCTSNode`
    current node to simulate from
- `transition::Function`
    a function that defines how the state transitions
- `transitionargs::NamedTuple`
    arguments for the transition function

# Keyword Arguments
- `maxSimulationDepth::Integer`
    number of levels MCTS descends during the simulation phase (default: 3)
- `horizontalSampleSimulationPhase::Integer`
    number of child states MCTS samples at each node during the simulation phase
    (default: 3)
- `saveSimulatedNode::Bool`
    whether to keep the nodes created during the simulation; when `false`,
    `node.children` is reset to an empty `Dict` afterwards (default: false)
- `multithread`
    forwarded to `simulate` (default: false)
- `highValueState`
    a container supporting `put!` (e.g. a `Channel`) that receives a deep copy of the
    simulation's terminal state when its `:reward` is `>= highValueThreshold`;
    `nothing` disables collection (default: nothing)
- `highValueThreshold::Real`
    minimum `:reward` for a terminal state to be collected (default: 8)
- `backpropLock::Union{Nothing,Base.AbstractLock}`
    lock held while backpropagating; pass one when several calls run concurrently on
    nodes that share ancestors (default: nothing, no locking)

# Returns
`nothing`; the effect is the update applied to the tree by `backpropagate`.
"""
function simulateThenBackpropagate(node::MCTSNode, transition::Function,
                                   transitionargs::NamedTuple;
                                   maxSimulationDepth::Integer=3,
                                   horizontalSampleSimulationPhase::Integer=3,
                                   saveSimulatedNode::Bool=false,
                                   multithread=false,
                                   highValueState=nothing,
                                   highValueThreshold::Real=8,
                                   backpropLock::Union{Nothing,Base.AbstractLock}=nothing)
    simTrajectoryReward, terminalstate =
        simulate(node, transition, transitionargs;
                 maxSimulationDepth=maxSimulationDepth,
                 horizontalSample=horizontalSampleSimulationPhase,
                 multithread=multithread)

    # store a high-value terminal state so the caller can report it
    if highValueState !== nothing &&
       terminalstate !== nothing &&
       terminalstate[:reward] >= highValueThreshold
        put!(highValueState, deepcopy(terminalstate))
    end

    if backpropLock === nothing
        backpropagate(node, simTrajectoryReward)
    else
        lock(backpropLock) do
            backpropagate(node, simTrajectoryReward)
        end
    end

    # drop the simulated subtree unless the caller wants to keep it
    if !saveSimulatedNode
        node.children = Dict{String,MCTSNode}()
    end
    return nothing
end

end # module interface