267 lines
7.0 KiB
Julia
267 lines
7.0 KiB
Julia
module interface
|
|
|
|
export runMCTS
|
|
|
|
using Base.Threads, PrettyPrinting
|
|
using ..type, ..mcts, ..util
|
|
|
|
|
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
|
|
|
|
""" Search the best action to take for a given state and task

# Arguments
    - `initialstate::T`
        initial state, stored as the state of the root node
    - `transition::Function`
        a function that defines how the state transitions
    - `transitionargs::NamedTuple`
        arguments for transition function

# Keyword Arguments
    - `horizontalSampleExpansionPhase::Integer`
        a number of child states MCTS samples at each node during expansion phase (default: 3)
    - `horizontalSampleSimulationPhase::Integer`
        a number of child states MCTS samples at each node during simulation's expansion phase
        (default: 3)
    - `maxSimulationDepth::Integer`
        a number of levels MCTS goes during simulation phase (default: 3)
    - `maxiterations::Integer`
        a number of iterations MCTS goes thru selection -> expansion -> simulation ->
        backpropagation cycle (default: 10)
    - `explorationweight::Number`
        exploration weight controls how much MCTS should explore new state instead of exploit
        a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes
        MCTS aggressively explore new state (default: 1.0)
    - `earlystop::Union{Function,Nothing}`
        optional function called at the end of every iteration with the state of the node
        selected in that iteration; the search stops as soon as it returns `true`
        (default: nothing)
    - `saveSimulatedNode::Bool`
        whether to save nodes created during simulation phase (default: false)
    - `multithread::Bool`
        whether to use multithreading during simulation (default: false)

# Returns
    - `NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
       Tuple{MCTSNode, T, T, Vector{Dict{Symbol,Any}}}}`
        - root: the complete MCTS tree with root node
        - bestNextState: state of the node returned by `selectBestNextNode(root)`
        - bestTerminalState: state of the node returned by `selectBestTrajectoryNode(root)`
        - highValueStateList: deep copies of every state with `state[:reward] >= 8` that was
          collected during the search (each one must be convertible to `Dict{Symbol,Any}`)

# Example
    Refers to SQLLLM package

# Signature
"""
function runMCTS(
    initialstate::T,
    transition::Function,
    transitionargs::NamedTuple;
    horizontalSampleExpansionPhase::Integer=3,
    horizontalSampleSimulationPhase::Integer=3,
    maxSimulationDepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
    earlystop::Union{Function,Nothing}=nothing,
    saveSimulatedNode::Bool=false,
    multithread=false,
)::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
    Tuple{MCTSNode,T,T,Vector{Dict{Symbol,Any}}}} where {T<:Any}

    root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
        Dict{Symbol,Any}())

    # Storage for all high reward states found during the search. Nothing takes from this
    # channel until the search loop has finished, so it must be unbounded: with a fixed
    # capacity, `put!` would block forever once the buffer is full.
    highValueState = Channel{Any}(Inf)

    # Keyword arguments shared by the threaded and the sequential simulation branches.
    simulationkwargs = (
        maxSimulationDepth=maxSimulationDepth,
        horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
        saveSimulatedNode=saveSimulatedNode,
        multithread=multithread,
        highValueState=highValueState,
    )

    for _ in 1:maxiterations
        node = root
        node.visits += 1

        # selection: walk down the tree until a leaf node is reached
        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            if node.state[:reward] >= 8
                put!(highValueState, deepcopy(node.state))
            end

            # MCTS arrived at a leaf node that is also a terminal state,
            # do nothing then go directly to backpropagation. It means the end of this iteration
            backpropagate(node, node.reward)
        else
            expand(node, transition, transitionargs;
                horizontalSample=horizontalSampleExpansionPhase,
                multithread=multithread)

            if multithread
                # NOTE(review): sibling simulations run concurrently and each one calls
                # `backpropagate`; confirm that `backpropagate` is safe to call from
                # several tasks at once.
                @sync for leafNode in values(node.children)
                    @spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
                        simulationkwargs...)
                end
            else
                for leafNode in values(node.children)
                    simulateThenBackpropagate(leafNode, transition, transitionargs;
                        simulationkwargs...)
                end
            end
        end

        # stop if the early stop condition is met
        if earlystop !== nothing && earlystop(node.state)
            break
        end
    end

    # select the best next state and the best terminal state along the best trajectory
    bestNextState = selectBestNextNode(root)
    bestTerminalState = selectBestTrajectoryNode(root)

    # take all high value states from the highValueState channel and put them in a list
    highValueStateList = Vector{Dict{Symbol,Any}}()
    while isready(highValueState)
        push!(highValueStateList, take!(highValueState))
    end

    result = (
        root=root,
        bestNextState=bestNextState.state,
        bestTerminalState=bestTerminalState.state,
        highValueStateList=highValueStateList,
    )

    return result
end
|
|
|
""" Simulate from a node, record a high value terminal state if any, then backpropagate
    the simulated trajectory reward

# Arguments
    - `node::MCTSNode`
        current node to simulate from
    - `transition::Function`
        a function that defines how the state transitions
    - `transitionargs::NamedTuple`
        arguments for transition function

# Keyword Arguments
    - `maxSimulationDepth::Integer`
        a number of levels MCTS goes during simulation phase (default: 3)
    - `horizontalSampleSimulationPhase::Integer`
        a number of child states MCTS samples at each node during simulation phase (default: 3)
    - `saveSimulatedNode::Bool`
        whether to save nodes created during simulation phase (default: false)
    - `multithread::Bool`
        whether to use multithreading during simulation (default: false)
    - `highValueState`
        optional container (anything supporting `put!`) that receives a deep copy of the
        simulated terminal state when `terminalstate[:reward] >= 8`; nothing is recorded when
        it is `nothing` (default: nothing)

# Returns
    Nothing, but updates the node's reward and visit count through backpropagation
"""
function simulateThenBackpropagate(node::MCTSNode, transition::Function,
    transitionargs::NamedTuple;
    maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
    saveSimulatedNode::Bool=false,
    multithread=false,
    highValueState=nothing)

    simTrajectoryReward, terminalstate =
        simulate(node, transition, transitionargs;
            maxSimulationDepth=maxSimulationDepth,
            horizontalSample=horizontalSampleSimulationPhase,
            multithread=multithread)

    # if the simulated terminal state has reward >= 8, store a copy of it in highValueState
    if highValueState !== nothing &&
       terminalstate !== nothing &&
       terminalstate[:reward] >= 8

        put!(highValueState, deepcopy(terminalstate))
    end

    backpropagate(node, simTrajectoryReward)

    # unless the user wants to keep the simulated nodes, drop this node's children
    if !saveSimulatedNode
        node.children = Dict{String,MCTSNode}()
    end

    return nothing
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module interface |