update
This commit is contained in:
@@ -2,7 +2,7 @@ module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using Base.Threads
|
||||
using Base.Threads, PrettyPrinting
|
||||
using ..type, ..mcts, ..util
|
||||
|
||||
|
||||
@@ -21,20 +21,34 @@ using ..type, ..mcts, ..util
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `totalsample::Integer`
|
||||
a number of child state MCTS sample at each node during expansion phase
|
||||
- `maxdepth::Integer`
|
||||
a number of levels MCTS goes during simulation phase
|
||||
- `horizontalSampleExpansionPhase::Integer`
|
||||
a number of child state MCTS sample at each node during expansion phase (default: 3)
|
||||
- `horizontalSampleSimulationPhase::Integer`
|
||||
a number of child state MCTS sample at each node during simulation's expansion phase (default: 3)
|
||||
- `maxSimulationDepth::Integer`
|
||||
|
||||
a number of levels MCTS goes during simulation phase (default: 3)
|
||||
- `maxiterations::Integer`
|
||||
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
|
||||
|
||||
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle (default: 10)
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new state instead of exploit
|
||||
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
|
||||
aggressively explore new state.
|
||||
|
||||
# Return
|
||||
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
|
||||
the best next state and the best final state
|
||||
aggressively explore new state (default: 1.0)
|
||||
- `earlystop::Union{Function,Nothing}`
|
||||
optional function to check early stopping condition (default: nothing)
|
||||
- `saveSimulatedNode::Bool`
|
||||
whether to save nodes created during simulation phase (default: false)
|
||||
|
||||
|
||||
|
||||
|
||||
# Returns
|
||||
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
||||
- mctstree: the complete MCTS tree with root node
|
||||
- bestNextState: the best immediate next state
|
||||
- bestFinalState: the best final state along the best trajectory
|
||||
|
||||
# Example
|
||||
Refers to SQLLLM package
|
||||
@@ -48,11 +62,12 @@ function runMCTS(
|
||||
;
|
||||
horizontalSampleExpansionPhase::Integer=3,
|
||||
horizontalSampleSimulationPhase::Integer=3,
|
||||
maxdepth::Integer=3,
|
||||
maxSimulationDepth::Integer=3,
|
||||
maxiterations::Integer=10,
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing
|
||||
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
earlystop::Union{Function,Nothing}=nothing,
|
||||
saveSimulatedNode::Bool=false) where {T<:Any}
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||
Dict{Symbol,Any}())
|
||||
@@ -72,31 +87,13 @@ function runMCTS(
|
||||
else
|
||||
_ = expand(node, transition, transitionargs;
|
||||
horizontalSample=horizontalSampleExpansionPhase)
|
||||
#[WORKING] make simulation parallel, leafNodes must be newly expanded nodes
|
||||
|
||||
leafNode = selectChildNode(node)
|
||||
|
||||
|
||||
|
||||
|
||||
# outputch = Channel(8)
|
||||
|
||||
#[WORKING] test whether multiple spawn retain result leafNode's child node
|
||||
|
||||
@spawn simulate(outputch, leafNode, transition, transitionargs;
|
||||
maxdepth=maxdepth, horizontalSample=horizontalSampleSimulationPhase)
|
||||
# if terminalstate !== nothing #XXX not sure why I need this
|
||||
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# end
|
||||
|
||||
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# open("trajectory.json", "w") do io
|
||||
# JSON3.pretty(io, terminalstate)
|
||||
# end
|
||||
|
||||
# result = take!(outputch)
|
||||
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
@sync for (leafNodeKey, leafNode) in node.children
|
||||
@spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
saveSimulatedNode=saveSimulatedNode)
|
||||
end
|
||||
end
|
||||
|
||||
# stop if the early stop condition is met
|
||||
@@ -105,13 +102,30 @@ function runMCTS(
|
||||
end
|
||||
end
|
||||
|
||||
# select the best next state and the best final state
|
||||
bestNextState = selectBestNextNode(root)
|
||||
besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
end
|
||||
|
||||
|
||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||
saveSimulatedNode::Bool=false)
|
||||
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSample=horizontalSampleSimulationPhase)
|
||||
backpropagate(node, simTrajectoryReward)
|
||||
|
||||
# check if the user wants to keep the simulated node
|
||||
if saveSimulatedNode == false
|
||||
node.children = Dict{String, MCTSNode}()
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
|
||||
# function runMCTS(
|
||||
# initialstate::T,
|
||||
# transition::Function,
|
||||
|
||||
Reference in New Issue
Block a user