This commit is contained in:
2025-03-15 08:28:13 +07:00
parent 2eff443f70
commit b2c53ffa45
4 changed files with 94 additions and 52 deletions

View File

@@ -26,27 +26,23 @@ using ..type, ..mcts, ..util
- `horizontalSampleSimulationPhase::Integer`
the number of child states MCTS samples at each node during the simulation's expansion phase (default: 3)
- `maxSimulationDepth::Integer`
the number of levels MCTS descends during the simulation phase (default: 3)
- `maxiterations::Integer`
the number of times MCTS runs the expansion -> simulation -> backpropagation cycle (default: 10)
- `explorationweight::Number`
controls how much MCTS explores new states instead of exploiting known ones. 1.0 balances
exploration and exploitation evenly (50%-50%); 2.0 makes MCTS explore new states
aggressively (default: 1.0)
- `earlystop::Union{Function,Nothing}`
optional function to check early stopping condition (default: nothing)
- `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false)
- `multithread::Bool`
whether to use multithreading during simulation (default: false)
# Returns
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
- mctstree: the complete MCTS tree with root node
- `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
- root: the complete MCTS tree with root node
- bestNextState: the best immediate next state
- bestFinalState: the best final state along the best trajectory
@@ -67,8 +63,8 @@ function runMCTS(
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false,
multithread=false) where {T<:Any}
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
multithread=false
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}())
@@ -121,7 +117,29 @@ function runMCTS(
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
""" Simulate from the given node, then backpropagate the simulation result
# Arguments
- `node::MCTSNode`
current node to simulate from
- `transition::Function`
a function that defines how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `maxSimulationDepth::Integer`
the number of levels MCTS descends during the simulation phase (default: 3)
- `horizontalSampleSimulationPhase::Integer`
a number of child states MCTS samples at each node during simulation phase (default: 3)
- `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false)
- `multithread::Bool`
whether to use multithreading during simulation (default: false)
# Returns
Nothing, but updates the node's reward and visit count through backpropagation
"""
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false,
@@ -209,9 +227,6 @@ end