This commit is contained in:
narawat lamaiin
2024-12-09 20:28:02 +07:00
parent 3338085567
commit cae94e5690
15 changed files with 2942 additions and 2942 deletions

View File

@@ -1,228 +1,228 @@
module interface
export runMCTS
using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `initialstate::T`
initial state
- `transition::Function`
a function that define how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `totalsample::Integer`
a number of child state MCTS sample at each node during expansion phase
- `maxdepth::Integer`
a number of levels MCTS goes during simulation phase
- `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
- `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state.
# Return
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
the best next state and the best final state
# Example
Refers to SQLLLM package
# Signature
"""
function runMCTS(
initialstate::T,
transition::Function,
transitionargs::NamedTuple,
;
totalsample::Integer=3,
maxdepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward)
else
expand(node, transition, transitionargs;
totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample)
# if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
# stop if the early stop condition is met
if typeof(earlystop) <: Function && earlystop(node.state)
break
end
end
bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root)
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
# function runMCTS(
# initialstate::T,
# transition::Function,
# transitionargs::NamedTuple,
# ;
# totalsample::Integer=3,
# maxdepth::Integer=3,
# maxiterations::Integer=10,
# explorationweight::Number=1.0,
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
# for nth in 1:maxiterations
# node = root
# node.visits += 1
# while !isleaf(node)
# node = UCTselect(node, explorationweight)
# end
# if node.isterminal
# # MCTS arrive at the leaf node that is also a terminal state,
# # do nothing then go directly to backpropagation. It means the end of this iteration
# backpropagate(leafNode, node.reward)
# else
# expand(node, transition, transitionargs;
# totalsample=totalsample)
# leafNode = selectChildNode(node)
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
# maxdepth=maxdepth, totalsample=totalsample)
# # if terminalstate !== nothing #XXX not sure why I need this
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# # end
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# # open("trajectory.json", "w") do io
# # JSON3.pretty(io, terminalstate)
# # end
# backpropagate(leafNode, simTrajectoryReward)
# end
# end
# bestNextState = selectBestNextNode(root)
# besttrajectory = selectBestTrajectoryNode(root)
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
# end
module interface
export runMCTS
using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `initialstate::T`
initial state
- `transition::Function`
a function that define how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `totalsample::Integer`
a number of child state MCTS sample at each node during expansion phase
- `maxdepth::Integer`
a number of levels MCTS goes during simulation phase
- `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
- `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state.
# Return
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
the best next state and the best final state
# Example
Refers to SQLLLM package
# Signature
"""
function runMCTS(
initialstate::T,
transition::Function,
transitionargs::NamedTuple,
;
totalsample::Integer=3,
maxdepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward)
else
expand(node, transition, transitionargs;
totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample)
# if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
# stop if the early stop condition is met
if typeof(earlystop) <: Function && earlystop(node.state)
break
end
end
bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root)
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
# function runMCTS(
# initialstate::T,
# transition::Function,
# transitionargs::NamedTuple,
# ;
# totalsample::Integer=3,
# maxdepth::Integer=3,
# maxiterations::Integer=10,
# explorationweight::Number=1.0,
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
# for nth in 1:maxiterations
# node = root
# node.visits += 1
# while !isleaf(node)
# node = UCTselect(node, explorationweight)
# end
# if node.isterminal
# # MCTS arrive at the leaf node that is also a terminal state,
# # do nothing then go directly to backpropagation. It means the end of this iteration
# backpropagate(leafNode, node.reward)
# else
# expand(node, transition, transitionargs;
# totalsample=totalsample)
# leafNode = selectChildNode(node)
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
# maxdepth=maxdepth, totalsample=totalsample)
# # if terminalstate !== nothing #XXX not sure why I need this
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# # end
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# # open("trajectory.json", "w") do io
# # JSON3.pretty(io, terminalstate)
# # end
# backpropagate(leafNode, simTrajectoryReward)
# end
# end
# bestNextState = selectBestNextNode(root)
# besttrajectory = selectBestTrajectoryNode(root)
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
# end
end # module interface