update
This commit is contained in:
454
src/interface.jl
454
src/interface.jl
@@ -1,228 +1,228 @@
|
||||
module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using ..type, ..mcts, ..util
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `initialstate::T`
|
||||
initial state
|
||||
- `transition::Function`
|
||||
a function that defines how the state transitions
|
||||
- `transitionargs::NamedTuple`
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `totalsample::Integer`
|
||||
the number of child states MCTS samples at each node during the expansion phase
|
||||
- `maxdepth::Integer`
|
||||
the number of levels MCTS descends during the simulation phase
|
||||
- `maxiterations::Integer`
|
||||
the number of iterations of the expansion -> simulation -> backpropagation cycle MCTS runs
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new states instead of exploiting
|
||||
a known state. 1.0 balances exploration and exploitation roughly 50%-50%. 2.0 makes MCTS
|
||||
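aggressively explore new states.
- `earlystop::Union{Function,Nothing}`
optional predicate called at the end of each iteration with the state of the node reached by
the selection step; the search stops as soon as it returns `true`. Defaults to `nothing`
(no early stopping).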
|
||||
|
||||
# Return
|
||||
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
|
||||
the best next state and the best final state
|
||||
|
||||
# Example
|
||||
Refer to the SQLLLM package
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function runMCTS(
    initialstate::T,
    transition::Function,
    transitionargs::NamedTuple,
    ;
    totalsample::Integer=3,
    maxdepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
    earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
    # `earlystop` (optional): predicate called at the end of every iteration with the state
    # of the node reached by the selection step. The search loop exits as soon as it returns
    # `true`; with the default `nothing` all `maxiterations` iterations are run.

    root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())

    for _ in 1:maxiterations
        # Selection: start at the root and follow UCTselect until a leaf node is reached.
        node = root
        node.visits += 1

        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            # MCTS arrived at a leaf node that is also a terminal state: nothing to expand
            # or simulate, so go directly to backpropagation with the node's own reward.
            backpropagate(node, node.reward)
        else
            # Expansion, then simulation from one of the children of the selected node.
            expand(node, transition, transitionargs; totalsample=totalsample)
            leafNode = selectChildNode(node)
            simTrajectoryReward, terminalstate = simulate(
                leafNode, transition, transitionargs;
                maxdepth=maxdepth, totalsample=totalsample
            )
            # NOTE: `terminalstate` is currently unused; it is kept for the TODOs below.
            # if terminalstate !== nothing #XXX not sure why I need this
            #     terminalstate[:totalTrajectoryReward] = simTrajectoryReward
            # end

            #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
            # open("trajectory.json", "w") do io
            #     JSON3.pretty(io, terminalstate)
            # end

            backpropagate(leafNode, simTrajectoryReward)
        end

        # Stop if the early stop condition is met. The predicate receives the state of the
        # node picked by the selection step (not the simulated child).
        if earlystop !== nothing && earlystop(node.state)
            break
        end
    end

    bestNextNode = selectBestNextNode(root)
    bestTrajectoryNode = selectBestTrajectoryNode(root)

    return (bestNextState=bestNextNode.state, bestFinalState=bestTrajectoryNode.state)
end
|
||||
|
||||
|
||||
# function runMCTS(
|
||||
# initialstate::T,
|
||||
# transition::Function,
|
||||
# transitionargs::NamedTuple,
|
||||
# ;
|
||||
# totalsample::Integer=3,
|
||||
# maxdepth::Integer=3,
|
||||
# maxiterations::Integer=10,
|
||||
# explorationweight::Number=1.0,
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
# for nth in 1:maxiterations
|
||||
# node = root
|
||||
# node.visits += 1
|
||||
|
||||
# while !isleaf(node)
|
||||
# node = UCTselect(node, explorationweight)
|
||||
# end
|
||||
# if node.isterminal
|
||||
# # MCTS arrive at the leaf node that is also a terminal state,
|
||||
# # do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
# backpropagate(leafNode, node.reward)
|
||||
# else
|
||||
# expand(node, transition, transitionargs;
|
||||
# totalsample=totalsample)
|
||||
# leafNode = selectChildNode(node)
|
||||
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
# maxdepth=maxdepth, totalsample=totalsample)
|
||||
# # if terminalstate !== nothing #XXX not sure why I need this
|
||||
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# # end
|
||||
|
||||
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# # open("trajectory.json", "w") do io
|
||||
# # JSON3.pretty(io, terminalstate)
|
||||
# # end
|
||||
|
||||
# backpropagate(leafNode, simTrajectoryReward)
|
||||
# end
|
||||
# end
|
||||
|
||||
# bestNextState = selectBestNextNode(root)
|
||||
# besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using ..type, ..mcts, ..util
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `initialstate::T`
|
||||
initial state
|
||||
- `transition::Function`
|
||||
a function that defines how the state transitions
|
||||
- `transitionargs::NamedTuple`
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `totalsample::Integer`
|
||||
the number of child states MCTS samples at each node during the expansion phase
|
||||
- `maxdepth::Integer`
|
||||
the number of levels MCTS descends during the simulation phase
|
||||
- `maxiterations::Integer`
|
||||
the number of iterations of the expansion -> simulation -> backpropagation cycle MCTS runs
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new states instead of exploiting
|
||||
a known state. 1.0 balances exploration and exploitation roughly 50%-50%. 2.0 makes MCTS
|
||||
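aggressively explore new states.
- `earlystop::Union{Function,Nothing}`
optional predicate called at the end of each iteration with the state of the node reached by
the selection step; the search stops as soon as it returns `true`. Defaults to `nothing`
(no early stopping).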
|
||||
|
||||
# Return
|
||||
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
|
||||
the best next state and the best final state
|
||||
|
||||
# Example
|
||||
Refer to the SQLLLM package
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function runMCTS(
    initialstate::T,
    transition::Function,
    transitionargs::NamedTuple,
    ;
    totalsample::Integer=3,
    maxdepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
    earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
    # `earlystop` (optional): predicate called at the end of every iteration with the state
    # of the node reached by the selection step. The search loop exits as soon as it returns
    # `true`; with the default `nothing` all `maxiterations` iterations are run.

    root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())

    for _ in 1:maxiterations
        # Selection: start at the root and follow UCTselect until a leaf node is reached.
        node = root
        node.visits += 1

        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            # MCTS arrived at a leaf node that is also a terminal state: nothing to expand
            # or simulate, so go directly to backpropagation with the node's own reward.
            backpropagate(node, node.reward)
        else
            # Expansion, then simulation from one of the children of the selected node.
            expand(node, transition, transitionargs; totalsample=totalsample)
            leafNode = selectChildNode(node)
            simTrajectoryReward, terminalstate = simulate(
                leafNode, transition, transitionargs;
                maxdepth=maxdepth, totalsample=totalsample
            )
            # NOTE: `terminalstate` is currently unused; it is kept for the TODOs below.
            # if terminalstate !== nothing #XXX not sure why I need this
            #     terminalstate[:totalTrajectoryReward] = simTrajectoryReward
            # end

            #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
            # open("trajectory.json", "w") do io
            #     JSON3.pretty(io, terminalstate)
            # end

            backpropagate(leafNode, simTrajectoryReward)
        end

        # Stop if the early stop condition is met. The predicate receives the state of the
        # node picked by the selection step (not the simulated child).
        if earlystop !== nothing && earlystop(node.state)
            break
        end
    end

    bestNextNode = selectBestNextNode(root)
    bestTrajectoryNode = selectBestTrajectoryNode(root)

    return (bestNextState=bestNextNode.state, bestFinalState=bestTrajectoryNode.state)
end
|
||||
|
||||
|
||||
# function runMCTS(
|
||||
# initialstate::T,
|
||||
# transition::Function,
|
||||
# transitionargs::NamedTuple,
|
||||
# ;
|
||||
# totalsample::Integer=3,
|
||||
# maxdepth::Integer=3,
|
||||
# maxiterations::Integer=10,
|
||||
# explorationweight::Number=1.0,
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
# for nth in 1:maxiterations
|
||||
# node = root
|
||||
# node.visits += 1
|
||||
|
||||
# while !isleaf(node)
|
||||
# node = UCTselect(node, explorationweight)
|
||||
# end
|
||||
# if node.isterminal
|
||||
# # MCTS arrive at the leaf node that is also a terminal state,
|
||||
# # do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
# backpropagate(leafNode, node.reward)
|
||||
# else
|
||||
# expand(node, transition, transitionargs;
|
||||
# totalsample=totalsample)
|
||||
# leafNode = selectChildNode(node)
|
||||
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
# maxdepth=maxdepth, totalsample=totalsample)
|
||||
# # if terminalstate !== nothing #XXX not sure why I need this
|
||||
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# # end
|
||||
|
||||
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# # open("trajectory.json", "w") do io
|
||||
# # JSON3.pretty(io, terminalstate)
|
||||
# # end
|
||||
|
||||
# backpropagate(leafNode, simTrajectoryReward)
|
||||
# end
|
||||
# end
|
||||
|
||||
# bestNextState = selectBestNextNode(root)
|
||||
# besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module interface
|
||||
Reference in New Issue
Block a user