update
This commit is contained in:
@@ -26,27 +26,23 @@ using ..type, ..mcts, ..util
|
||||
- `horizontalSampleSimulationPhase::Integer`
|
||||
the number of child states MCTS samples at each node during the simulation's expansion phase (default: 3)
|
||||
- `maxSimulationDepth::Integer`
|
||||
|
||||
the maximum number of levels MCTS descends during the simulation phase (default: 3)
|
||||
- `maxiterations::Integer`
|
||||
|
||||
the number of iterations MCTS runs through the expansion -> simulation -> backpropagation cycle (default: 10)
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new states instead of exploiting
|
||||
a known state. 1.0 balances exploration and exploitation roughly 50%-50%. 2.0 makes MCTS
|
||||
|
||||
aggressively explore new states (default: 1.0)
|
||||
- `earlystop::Union{Function,Nothing}`
|
||||
optional function to check early stopping condition (default: nothing)
|
||||
- `saveSimulatedNode::Bool`
|
||||
whether to save nodes created during simulation phase (default: false)
|
||||
|
||||
|
||||
|
||||
- `multithread::Bool`
|
||||
whether to use multithreading during simulation (default: false)
|
||||
|
||||
# Returns
|
||||
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
||||
- mctstree: the complete MCTS tree with root node
|
||||
- `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
||||
- root: the complete MCTS tree with root node
|
||||
- bestNextState: the best immediate next state
|
||||
- bestFinalState: the best final state along the best trajectory
|
||||
|
||||
@@ -67,8 +63,8 @@ function runMCTS(
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing,
|
||||
saveSimulatedNode::Bool=false,
|
||||
multithread=false) where {T<:Any}
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
multithread=false
|
||||
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||
Dict{Symbol,Any}())
|
||||
@@ -121,7 +117,29 @@ function runMCTS(
|
||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
end
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
current node to simulate from
|
||||
- `transition::Function`
|
||||
a function that defines how the state transitions
|
||||
- `transitionargs::NamedTuple`
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `maxSimulationDepth::Integer`
|
||||
the maximum number of levels MCTS descends during the simulation phase (default: 3)
|
||||
- `horizontalSampleSimulationPhase::Integer`
|
||||
a number of child states MCTS samples at each node during simulation phase (default: 3)
|
||||
- `saveSimulatedNode::Bool`
|
||||
whether to save nodes created during simulation phase (default: false)
|
||||
- `multithread::Bool`
|
||||
whether to use multithreading during simulation (default: false)
|
||||
|
||||
# Returns
|
||||
Nothing, but updates the node's reward and visit count through backpropagation
|
||||
"""
|
||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||
saveSimulatedNode::Bool=false,
|
||||
@@ -209,9 +227,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user