update
This commit is contained in:
@@ -26,27 +26,23 @@ using ..type, ..mcts, ..util
|
||||
- `horizontalSampleSimulationPhase::Integer`
|
||||
a number of child state MCTS sample at each node during simulation's expansion phase (default: 3)
|
||||
- `maxSimulationDepth::Integer`
|
||||
|
||||
a number of levels MCTS goes during simulation phase (default: 3)
|
||||
- `maxiterations::Integer`
|
||||
|
||||
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle (default: 10)
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new state instead of exploit
|
||||
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
|
||||
|
||||
aggressively explore new state (default: 1.0)
|
||||
- `earlystop::Union{Function,Nothing}`
|
||||
optional function to check early stopping condition (default: nothing)
|
||||
- `saveSimulatedNode::Bool`
|
||||
whether to save nodes created during simulation phase (default: false)
|
||||
|
||||
|
||||
|
||||
- `multithread::Bool`
|
||||
whether to use multithreading during simulation (default: false)
|
||||
|
||||
# Returns
|
||||
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
||||
- mctstree: the complete MCTS tree with root node
|
||||
- `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
||||
- root: the complete MCTS tree with root node
|
||||
- bestNextState: the best immediate next state
|
||||
- bestFinalState: the best final state along the best trajectory
|
||||
|
||||
@@ -67,8 +63,8 @@ function runMCTS(
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing,
|
||||
saveSimulatedNode::Bool=false,
|
||||
multithread=false) where {T<:Any}
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
multithread=false
|
||||
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||
Dict{Symbol,Any}())
|
||||
@@ -121,7 +117,29 @@ function runMCTS(
|
||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
end
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
current node to simulate from
|
||||
- `transition::Function`
|
||||
a function that defines how the state transitions
|
||||
- `transitionargs::NamedTuple`
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `maxSimulationDepth::Integer`
|
||||
a number of levels MCTS goes during simulation phase (default: 3)
|
||||
- `horizontalSampleSimulationPhase::Integer`
|
||||
a number of child states MCTS samples at each node during simulation phase (default: 3)
|
||||
- `saveSimulatedNode::Bool`
|
||||
whether to save nodes created during simulation phase (default: false)
|
||||
- `multithread::Bool`
|
||||
whether to use multithreading during simulation (default: false)
|
||||
|
||||
# Returns
|
||||
Nothing, but updates the node's reward and visit count through backpropagation
|
||||
"""
|
||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||
saveSimulatedNode::Bool=false,
|
||||
@@ -209,9 +227,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
68
src/mcts.jl
68
src/mcts.jl
@@ -10,27 +10,29 @@ using ..type
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
"""
|
||||
""" Select the best next node based on the highest value metric
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
node of a search tree to evaluate
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
the child node with highest value based on either:
|
||||
- statevalue/visits ratio if any nodes have non-zero statevalue
|
||||
- progressvalue + reward otherwise
|
||||
"""
|
||||
function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = -1
|
||||
nodekey = nothing
|
||||
|
||||
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
||||
# Calculate sum of statevalues across all child nodes
|
||||
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
||||
|
||||
# If any nodes have non-zero statevalue, use statevalue/visits as selection metric
|
||||
if stateValueSum != 0
|
||||
for (k, childnode) in node.children
|
||||
# Calculate average statevalue per visit
|
||||
potential = childnode.statevalue / childnode.visits
|
||||
|
||||
if potential > highestProgressValue
|
||||
@@ -39,6 +41,7 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||
end
|
||||
end
|
||||
else
|
||||
# Otherwise use progressvalue + reward as selection metric
|
||||
for (k, childnode) in node.children
|
||||
potential = childnode.progressvalue + childnode.reward
|
||||
|
||||
@@ -53,15 +56,16 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
""" Select the best trajectory node based on the highest reward
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
node of a search tree to evaluate
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
the highest value child node found by traversing down the tree using selectBestNextNode
|
||||
until reaching a leaf node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -86,7 +90,8 @@ end
|
||||
reward it is now.
|
||||
|
||||
# Return
|
||||
- `None`
|
||||
- `Nothing`
|
||||
This function modifies the nodes in place and returns nothing
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -94,22 +99,23 @@ function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||
while !isroot(node)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||
node.visits += 1 # Increment visit count for this node
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits # Update running average of state value
|
||||
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||
node = node.parent
|
||||
node = node.parent # Move up to parent node for next iteration
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Determine whether a node is a leaf node of a search tree.
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
a search tree node
|
||||
|
||||
# Return
|
||||
- `result::Bool`
|
||||
true if it is a leaf node, false otherwise.
|
||||
true if it is a leaf node (has no children), false otherwise.
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> using Revise
|
||||
@@ -128,14 +134,10 @@ julia> YiemAgent.isleaf(root)
|
||||
true
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docs
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
|
||||
""" Determine wheter a given node is a root node
|
||||
|
||||
# Arguments
|
||||
@@ -185,14 +187,18 @@ end
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
MCTS node
|
||||
MCTS node to expand
|
||||
- `transition::Function`
|
||||
A function that handles state transition.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for transition()
|
||||
- `totalsample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
|
||||
# Keyword Arguments
|
||||
- `horizontalSample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally). Defaults to 3.
|
||||
- `multithread::Bool`
|
||||
Whether to run expansion in parallel using multiple threads. Defaults to false.
|
||||
|
||||
# Return
|
||||
- None
|
||||
|
||||
@@ -211,6 +217,21 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
end
|
||||
end
|
||||
|
||||
""" Helper function to expand a single child node.
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
Parent MCTS node to expand from
|
||||
- `transition::Function`
|
||||
A function that handles state transition
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for transition()
|
||||
|
||||
# Return
|
||||
- None
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
||||
result = transition(node.state, transitionargs)
|
||||
newNodeKey::AbstractString = result[:newNodeKey]
|
||||
@@ -231,7 +252,6 @@ function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Simulate interactions between agent and environment
|
||||
|
||||
# Arguments
|
||||
|
||||
34
src/type.jl
34
src/type.jl
@@ -10,17 +10,27 @@ using GeneralUtils
|
||||
""" a node for MCTS search tree
|
||||
|
||||
# Arguments
|
||||
- `state::T`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
- `visits::Integer `
|
||||
number of time the game visits this state
|
||||
- `stateValue::Float64`
|
||||
state value
|
||||
- `children::Dict{T, MCTSNode}`
|
||||
children node
|
||||
- `nodekey::AbstractString`
|
||||
unique identifier for the node
|
||||
- `state::AbstractDict`
|
||||
a state of a game represented as a dictionary
|
||||
- `visits::Integer`
|
||||
number of times the game visits this state
|
||||
- `progressvalue::Number`
|
||||
estimated value by LLM's reasoning
|
||||
- `statevalue::Number`
|
||||
current state value, stores node's immediate reward and future discounted rewards
|
||||
- `reward::Number`
|
||||
immediate reward for this node
|
||||
- `isterminal::Bool`
|
||||
whether this node represents a terminal state
|
||||
- `parent::Union{MCTSNode, Nothing}`
|
||||
reference to parent node, Nothing for root
|
||||
- `children::Dict{String, MCTSNode}`
|
||||
mapping of child nodes
|
||||
- `etc::Dict{Symbol, Any}`
|
||||
additional storage for arbitrary data
|
||||
|
||||
# Return
|
||||
- `nothing`
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> state = Dict(
|
||||
@@ -36,9 +46,6 @@ julia> state = Dict(
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
|
||||
# Signature
|
||||
"""
|
||||
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||
@@ -113,7 +120,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module type
|
||||
@@ -17,6 +17,8 @@ using ..type
|
||||
Value 2.0 makes MCTS aggressively search the tree.
|
||||
# Return
|
||||
- `selectedNode::MCTSNode`
|
||||
child node with highest UCT score. UCT score balances between exploitation (state value)
|
||||
and exploration (visit count) based on the exploration weight w.
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
@@ -133,7 +135,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module util
|
||||
Reference in New Issue
Block a user