update
This commit is contained in:
@@ -26,27 +26,23 @@ using ..type, ..mcts, ..util
|
|||||||
- `horizontalSampleSimulationPhase::Integer`
|
- `horizontalSampleSimulationPhase::Integer`
|
||||||
the number of child states MCTS samples at each node during the simulation's expansion phase (default: 3)
|
the number of child states MCTS samples at each node during the simulation's expansion phase (default: 3)
|
||||||
- `maxSimulationDepth::Integer`
|
- `maxSimulationDepth::Integer`
|
||||||
|
|
||||||
the number of levels MCTS descends during the simulation phase (default: 3)
|
the number of levels MCTS descends during the simulation phase (default: 3)
|
||||||
- `maxiterations::Integer`
|
- `maxiterations::Integer`
|
||||||
|
|
||||||
the number of iterations MCTS runs through the expansion -> simulation -> backpropagation cycle (default: 10)
|
the number of iterations MCTS runs through the expansion -> simulation -> backpropagation cycle (default: 10)
|
||||||
- `explorationweight::Number`
|
- `explorationweight::Number`
|
||||||
exploration weight controls how much MCTS should explore new states instead of exploiting
|
exploration weight controls how much MCTS should explore new states instead of exploiting
|
||||||
a known state. 1.0 balances exploration and exploitation roughly 50%-50%. 2.0 makes MCTS
|
a known state. 1.0 balances exploration and exploitation roughly 50%-50%. 2.0 makes MCTS
|
||||||
|
|
||||||
aggressively explore new states (default: 1.0)
|
aggressively explore new states (default: 1.0)
|
||||||
- `earlystop::Union{Function,Nothing}`
|
- `earlystop::Union{Function,Nothing}`
|
||||||
optional function to check early stopping condition (default: nothing)
|
optional function to check early stopping condition (default: nothing)
|
||||||
- `saveSimulatedNode::Bool`
|
- `saveSimulatedNode::Bool`
|
||||||
whether to save nodes created during simulation phase (default: false)
|
whether to save nodes created during simulation phase (default: false)
|
||||||
|
- `multithread::Bool`
|
||||||
|
whether to use multithreading during simulation (default: false)
|
||||||
|
|
||||||
|
|
||||||
# Returns
|
# Returns
|
||||||
- `NamedTuple{(:mctstree, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
- `NamedTuple{(:root, :bestNextState, :bestFinalState), Tuple{MCTSNode, T, T}}`
|
||||||
- mctstree: the complete MCTS tree with root node
|
- root: the complete MCTS tree with root node
|
||||||
- bestNextState: the best immediate next state
|
- bestNextState: the best immediate next state
|
||||||
- bestFinalState: the best final state along the best trajectory
|
- bestFinalState: the best final state along the best trajectory
|
||||||
|
|
||||||
@@ -67,8 +63,8 @@ function runMCTS(
|
|||||||
explorationweight::Number=1.0,
|
explorationweight::Number=1.0,
|
||||||
earlystop::Union{Function,Nothing}=nothing,
|
earlystop::Union{Function,Nothing}=nothing,
|
||||||
saveSimulatedNode::Bool=false,
|
saveSimulatedNode::Bool=false,
|
||||||
multithread=false) where {T<:Any}
|
multithread=false
|
||||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
|
||||||
|
|
||||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||||
Dict{Symbol,Any}())
|
Dict{Symbol,Any}())
|
||||||
@@ -121,7 +117,29 @@ function runMCTS(
|
|||||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
""" Run a simulation from the given node, then backpropagate the result
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
current node to simulate from
|
||||||
|
- `transition::Function`
|
||||||
|
a function that defines how the state transitions
|
||||||
|
- `transitionargs::NamedTuple`
|
||||||
|
arguments for transition function
|
||||||
|
|
||||||
|
# Keyword Arguments
|
||||||
|
- `maxSimulationDepth::Integer`
|
||||||
|
the number of levels MCTS descends during the simulation phase (default: 3)
|
||||||
|
- `horizontalSampleSimulationPhase::Integer`
|
||||||
|
a number of child states MCTS samples at each node during simulation phase (default: 3)
|
||||||
|
- `saveSimulatedNode::Bool`
|
||||||
|
whether to save nodes created during simulation phase (default: false)
|
||||||
|
- `multithread::Bool`
|
||||||
|
whether to use multithreading during simulation (default: false)
|
||||||
|
|
||||||
|
# Returns
|
||||||
|
Nothing, but updates the node's reward and visit count through backpropagation
|
||||||
|
"""
|
||||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||||
saveSimulatedNode::Bool=false,
|
saveSimulatedNode::Bool=false,
|
||||||
@@ -209,9 +227,6 @@ end
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
66
src/mcts.jl
66
src/mcts.jl
@@ -10,27 +10,29 @@ using ..type
|
|||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
"""
|
""" Select the best next node based on the highest value metric
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
- `node::MCTSNode`
|
- `node::MCTSNode`
|
||||||
node of a search tree
|
node of a search tree to evaluate
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- `childNode::MCTSNode`
|
- `childNode::MCTSNode`
|
||||||
the highest value child node
|
the child node with highest value based on either:
|
||||||
|
- statevalue/visits ratio if any nodes have non-zero statevalue
|
||||||
# Signature
|
- progressvalue + reward otherwise
|
||||||
"""
|
"""
|
||||||
function selectBestNextNode(node::MCTSNode)::MCTSNode
|
function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||||
highestProgressValue = -1
|
highestProgressValue = -1
|
||||||
nodekey = nothing
|
nodekey = nothing
|
||||||
|
|
||||||
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
# Calculate sum of statevalues across all child nodes
|
||||||
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
||||||
|
|
||||||
|
# If any nodes have non-zero statevalue, use statevalue/visits as selection metric
|
||||||
if stateValueSum != 0
|
if stateValueSum != 0
|
||||||
for (k, childnode) in node.children
|
for (k, childnode) in node.children
|
||||||
|
# Calculate average statevalue per visit
|
||||||
potential = childnode.statevalue / childnode.visits
|
potential = childnode.statevalue / childnode.visits
|
||||||
|
|
||||||
if potential > highestProgressValue
|
if potential > highestProgressValue
|
||||||
@@ -39,6 +41,7 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
|
# Otherwise use progressvalue + reward as selection metric
|
||||||
for (k, childnode) in node.children
|
for (k, childnode) in node.children
|
||||||
potential = childnode.progressvalue + childnode.reward
|
potential = childnode.progressvalue + childnode.reward
|
||||||
|
|
||||||
@@ -53,15 +56,16 @@ function selectBestNextNode(node::MCTSNode)::MCTSNode
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
"""
|
""" Select the best trajectory node based on the highest reward
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
- `node::MCTSNode`
|
- `node::MCTSNode`
|
||||||
node of a search tree
|
node of a search tree to evaluate
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- `childNode::MCTSNode`
|
- `childNode::MCTSNode`
|
||||||
the highest value child node
|
the highest value child node found by traversing down the tree using selectBestNextNode
|
||||||
|
until reaching a leaf node
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -86,7 +90,8 @@ end
|
|||||||
reward it is now.
|
reward it is now.
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- `None`
|
- `Nothing`
|
||||||
|
This function modifies the nodes in place and returns nothing
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -94,22 +99,23 @@ function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
|||||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||||
while !isroot(node)
|
while !isroot(node)
|
||||||
# Update the statistics of the current node based on the result of the playout
|
# Update the statistics of the current node based on the result of the playout
|
||||||
node.visits += 1
|
node.visits += 1 # Increment visit count for this node
|
||||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits # Update running average of state value
|
||||||
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||||
node = node.parent
|
node = node.parent # Move up to parent node for next iteration
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
""" Determine whether a node is a leaf node of a search tree.
|
""" Determine whether a node is a leaf node of a search tree.
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
- `node::MCTSNode`
|
- `node::MCTSNode`
|
||||||
a search tree node
|
a search tree node
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- `result::Bool`
|
- `result::Bool`
|
||||||
true if it is a leaf node, false otherwise.
|
true if it is a leaf node (has no children), false otherwise.
|
||||||
|
|
||||||
# Example
|
# Example
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> using Revise
|
julia> using Revise
|
||||||
@@ -128,14 +134,10 @@ julia> YiemAgent.isleaf(root)
|
|||||||
true
|
true
|
||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
|
||||||
[] update docs
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
|
|
||||||
""" Determine whether a given node is a root node
|
""" Determine whether a given node is a root node
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -185,13 +187,17 @@ end
|
|||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
- `node::MCTSNode`
|
- `node::MCTSNode`
|
||||||
MCTS node
|
MCTS node to expand
|
||||||
- `transition::Function`
|
- `transition::Function`
|
||||||
A function that handles state transition.
|
A function that handles state transition.
|
||||||
- `transitionargs::NamedTuple`
|
- `transitionargs::NamedTuple`
|
||||||
Arguments for transition()
|
Arguments for transition()
|
||||||
- `totalsample::Integer`
|
|
||||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
# Keyword Arguments
|
||||||
|
- `horizontalSample::Integer`
|
||||||
|
Total number to sample from the current node (i.e. expand new node horizontally). Defaults to 3.
|
||||||
|
- `multithread::Bool`
|
||||||
|
Whether to run expansion in parallel using multiple threads. Defaults to false.
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- None
|
- None
|
||||||
@@ -211,6 +217,21 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
""" Helper function to expand a single child node.
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
Parent MCTS node to expand from
|
||||||
|
- `transition::Function`
|
||||||
|
A function that handles state transition
|
||||||
|
- `transitionargs::NamedTuple`
|
||||||
|
Arguments for transition()
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- None
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
||||||
result = transition(node.state, transitionargs)
|
result = transition(node.state, transitionargs)
|
||||||
newNodeKey::AbstractString = result[:newNodeKey]
|
newNodeKey::AbstractString = result[:newNodeKey]
|
||||||
@@ -231,7 +252,6 @@ function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
""" Simulate interactions between agent and environment
|
""" Simulate interactions between agent and environment
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
|||||||
32
src/type.jl
32
src/type.jl
@@ -10,17 +10,27 @@ using GeneralUtils
|
|||||||
""" a node for MCTS search tree
|
""" a node for MCTS search tree
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
- `state::T`
|
- `nodekey::AbstractString`
|
||||||
a state of a game. Can be a Dict or something else.
|
unique identifier for the node
|
||||||
|
- `state::AbstractDict`
|
||||||
|
a state of a game represented as a dictionary
|
||||||
- `visits::Integer`
|
- `visits::Integer`
|
||||||
number of time the game visits this state
|
number of times the game visits this state
|
||||||
- `stateValue::Float64`
|
- `progressvalue::Number`
|
||||||
state value
|
estimated value by LLM's reasoning
|
||||||
- `children::Dict{T, MCTSNode}`
|
- `statevalue::Number`
|
||||||
children node
|
current state value, stores node's immediate reward and future discounted rewards
|
||||||
|
- `reward::Number`
|
||||||
|
immediate reward for this node
|
||||||
|
- `isterminal::Bool`
|
||||||
|
whether this node represents a terminal state
|
||||||
|
- `parent::Union{MCTSNode, Nothing}`
|
||||||
|
reference to parent node, Nothing for root
|
||||||
|
- `children::Dict{String, MCTSNode}`
|
||||||
|
mapping of child nodes
|
||||||
|
- `etc::Dict{Symbol, Any}`
|
||||||
|
additional storage for arbitrary data
|
||||||
|
|
||||||
# Return
|
|
||||||
- `nothing`
|
|
||||||
# Example
|
# Example
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> state = Dict(
|
julia> state = Dict(
|
||||||
@@ -36,9 +46,6 @@ julia> state = Dict(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
|
||||||
[] update docstring
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||||
@@ -113,7 +120,6 @@ end
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end # module type
|
end # module type
|
||||||
@@ -17,6 +17,8 @@ using ..type
|
|||||||
Value 2.0 makes MCTS aggressively search the tree.
|
Value 2.0 makes MCTS aggressively search the tree.
|
||||||
# Return
|
# Return
|
||||||
- `selectedNode::MCTSNode`
|
- `selectedNode::MCTSNode`
|
||||||
|
child node with highest UCT score. UCT score balances between exploitation (state value)
|
||||||
|
and exploration (visit count) based on the exploration weight w.
|
||||||
|
|
||||||
# Example
|
# Example
|
||||||
```jldoctest
|
```jldoctest
|
||||||
@@ -133,7 +135,6 @@ end
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end # module util
|
end # module util
|
||||||
Reference in New Issue
Block a user