Files
YiemAgent/src/mcts.jl
narawat lamaiin 1fae63126f update
2024-05-06 17:01:47 +07:00

493 lines
11 KiB
Julia

""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, this module defines the search-tree node type together with
the selection, expansion, simulation and backpropagation functions used by `runMCTS`.
"""
module mcts
export MCTSNode, runMCTS, isleaf
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils
using ..type, ..llmfunction
# ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree
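# Arguments
- `nodekey::T2`
key of this node. The root node uses the key "root"
- `state::T1`
a state of a game, stored as a dictionary
- `visits::Integer`
number of times the search visits this state
- `stateevaluation::T2`
evaluation text of the state
- `statevalue::Number`
state value
- `reward::Number`
reward of the state
- `isterminal::Bool`
true if the state is a terminal state
- `parent::Union{MCTSNode, Nothing}`
parent node, `nothing` for the root node
- `children::Dict{String, MCTSNode}`
children nodes, stored under each child's `nodekey`
# Return
- a new `MCTSNode`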
# Example
```jldoctest
julia> state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
```
# TODO
[] update docstring
# Signature
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
# Node of the MCTS search tree. A node links to its parent and owns a dictionary of children,
# so the tree can be walked downwards (selection) and upwards (backpropagation).

# key of this node; `runMCTS` creates the root with the key "root", which `isroot` checks for
nodekey::T2
# game state held by this node; handed to `decisionMaker` and `MCTStransition` on expansion
state::T1
# visit counter, incremented by `backpropagate`
visits::Integer
# first value returned by `progressValueEstimator` for this state ("N/A" for the root)
stateevaluation::T2
# second value returned by `progressValueEstimator`; later updated by `backpropagate` and read
# by `UCTselect` and `selectChildNode`
statevalue::Number
# reward reported for this state by the `isterminal` function (0 for the root)
reward::Number
# terminal flag reported for this state by the `isterminal` function
isterminal::Bool
# parent node; `nothing` for the root
parent::Union{MCTSNode, Nothing}
# child nodes, stored under each child's `nodekey`
children::Dict{String, MCTSNode}
end
""" Select a node based on UCT score
# Arguments
- `node::MCTSNode`
mcts node
- `w::Float64`
exploration weight
# Return
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[x] check childNode.total_reward w/ LATS paper. Which value total_reward representing
# Signature
"""
function UCTselect(node::MCTSNode, w::Float64)
    # Scan the children once and keep the one with the largest UCT score
    # (statevalue + exploration bonus). On ties the child met first wins; a node without
    # children yields `nothing`.
    bestScore = -Inf
    bestChild = nothing
    parentVisits = node.visits

    for child in values(node.children)
        # the exploration bonus is left at zero while either visit counter is still zero
        bonus = 0
        if parentVisits != 0 && child.visits != 0
            bonus = w * sqrt(log(parentVisits) / child.visits)
        end

        score = child.statevalue + bonus
        if score > bestScore
            bestScore, bestChild = score, child
        end
    end

    return bestChild
end
""" Expand selected node
# Arguments
- `a::T1`
one of YiemAgent's agents
- `node::MCTSNode`
MCTS node to expand
- `decisionMaker::Function`
a function that outputs Thought and Action
- `progressValueEstimator::Function`
a function that outputs the state evaluation and the trajectory progress score
- `isterminal::Function`
a function to determine terminal state
- `n::Integer`
how many times an action is sampled from decisionMaker (keyword argument, default 3)
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
# Signature
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
    progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
    # Grow `node` by sampling `n` actions from `decisionMaker`. Each sampled action is applied
    # with `MCTStransition`, the resulting state is scored with `progressValueEstimator`, and
    # the outcome is stored as a new child of `node` (visits start at 0).
    #
    # Arguments
    #   a                      - agent forwarded to every callback
    #   node                   - node to expand; its `children` dictionary is mutated
    #   decisionMaker          - called as decisionMaker(a, state), returns the thought dict
    #   progressValueEstimator - called as progressValueEstimator(a, newstate), returns
    #                            (stateevaluation, statevalue)
    #   isterminal             - forwarded to `MCTStransition`
    #   n                      - number of actions to sample
    # Returns `nothing`.
    for _ in 1:n
        thoughtDict = decisionMaker(a, node.state)
        newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state,
            thoughtDict, isterminal)

        # score the freshly reached state
        stateevaluation, statevalue = progressValueEstimator(a, newstate)

        # never overwrite an existing child that already lives under the same key
        if !haskey(node.children, newNodeKey)
            node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation,
                statevalue, reward, isterminalstate, node, Dict{String, MCTSNode}())
        end
    end
    return nothing
end
"""
# Arguments
- `node::MCTSNode`
node that will be a simulation starting point.
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [x] implement the function
- [] check for the terminal state (node.reward != 0), break if it is terminal state
# Signature
"""
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
    isterminal::Function, maxDepth::Int; n=3)::Number
    # Roll out a trajectory starting at `node` and return the sum of the rewards of the nodes
    # it passes through.
    #
    # At each step the current node's reward is collected, then the node is expanded with
    # `expand` and the rollout moves on to the child picked by `selectChildNode`. The rollout
    # stops when a terminal node is reached, after `maxDepth` steps, or when a step fails.
    #
    # Arguments
    #   a, decisionMaker, progressValueEstimator, isterminal - forwarded to `expand`
    #   node     - starting point of the simulation
    #   maxDepth - maximum number of steps
    #   n        - number of actions sampled per expansion
    simTrajectoryReward = 0.0

    for depth in 1:maxDepth
        # collect the reward BEFORE the terminal check so the reward attached to a terminal
        # node is part of the trajectory total
        simTrajectoryReward += node.reward
        node.isterminal && break

        try
            expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
            node = selectChildNode(node)
        catch err
            # a user interrupt must never be swallowed
            err isa InterruptException && rethrow()
            # best effort: keep what was collected so far, but leave a trace of the failure
            @warn "simulate: rollout stopped early" depth exception=(err, catch_backtrace())
            break
        end
    end

    return simTrajectoryReward
end
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
    # Propagate the result of a simulation from `node` up to and including the root.
    #
    # For every node on the path the visit counter is incremented and `statevalue` is replaced
    # by the running mean  (old_value * (visits - 1) + reward) / visits.  The reward is
    # multiplied by `discountRewardCoeff` before each step towards the root.
    #
    # Arguments
    #   node                - node where the simulation started; needs the fields `visits`,
    #                         `statevalue` and `parent` (`nothing` marks the top of the tree)
    #   simTrajectoryReward - reward collected by the simulation
    #   discountRewardCoeff - discount factor applied per level
    # Returns `nothing`.
    while node !== nothing
        node.visits += 1
        # plain assignment: the expression already is the new mean
        node.statevalue =
            ((node.statevalue * (node.visits - 1)) + simTrajectoryReward) / node.visits
        simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
        # the root (parent === nothing) is updated as well, so its visit count is available
        # to the exploration term of the UCT formula
        node = node.parent
    end
    return nothing
end
""" Get a new state
# Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state
- `thoughtDict::T3`
contain Thought, Action, Observation
- `isterminal::Function`
a function to determine terminal state
# Return
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought=> "The customer wants to buy a bottle of wine.",
:action=> Dict{Symbol, Any}(
:name=>"chatbox",
:input=>"What occasion are you buying the wine for?",
),
)
```
# TODO
- [x] add other actions
- [] add embedding of newstate and store in newstate[:embedding]
# Signature
"""
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
    )::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
    # Apply the action described by `thoughtDict` to `state` and build the follow-up state.
    # The thought, the action and the observation are appended to a deep copy of `state`
    # under the next free index of its :thoughtHistory; `state` itself is left untouched.
    # Returns (newNodeKey, newstate, isterminalstate, reward).
    action = thoughtDict[:action]
    actionname = action[:name]
    actioninput = action[:input]

    # route the action to the matching LLM function; its answer becomes the observation
    response =
        actionname == "chatbox" ? virtualWineCustomerChatbox(a, actioninput) : # virtual customer
        actionname == "winestock" ? winestock(a, actioninput) :
        actionname == "recommendbox" ? virtualWineCustomerReccommendbox(a, actioninput) :
        error("undefined LLM function. Requesting $actionname")

    # index of the new entries: one past the highest existing thought, 1 when none is found
    highestKey, highestIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
        "thought")
    if highestKey == :NA
        nextIndice = 1
    else
        nextIndice = highestIndice + 1
    end
    thoughtKey = Symbol("thought_$nextIndice")
    actionKey = Symbol("action_$nextIndice")
    observationKey = Symbol("observation_$(nextIndice)")

    # record thought, action and observation (in this order) in the copied history
    newstate = deepcopy(state)
    history = newstate[:thoughtHistory]
    history[thoughtKey] = thoughtDict[:thought]
    history[actionKey] = action
    history[observationKey] = response

    newNodeKey = GeneralUtils.uuid4snakecase()
    isterminalstate, reward = isterminal(newstate)

    return (newNodeKey, newstate, isterminalstate, reward)
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?",
)
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode("root", initialState, 0, "N/A", 0, 0, false, nothing, Dict{String, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
function isleaf(node::MCTSNode)::Bool
    # a node is a leaf exactly when its children dictionary holds no entry
    return isempty(node.children)
end
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    # Return the child of `node` with the highest progress value, where the progress value of
    # a child is `statevalue + reward`. On ties the child met first wins.
    # Throws an ArgumentError when `node` has no children.
    isempty(node.children) &&
        throw(ArgumentError("selectChildNode: node \"$(node.nodekey)\" has no children"))

    highestProgressValue = -Inf
    selected = nothing

    # loop through the children dictionary to find the highest progress value
    for childNode in values(node.children)
        thisNodeProgressValue = childNode.statevalue + childNode.reward
        # compare the same quantity that is stored; the `selected === nothing` guard makes sure
        # a child is returned even when every progress value is zero, negative or NaN
        if selected === nothing || thisNodeProgressValue > highestProgressValue
            highestProgressValue = thisNodeProgressValue
            selected = childNode
        end
    end

    return selected
end
""" Determine wheter a given node is a root node
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Example
```jldoctest
julia>
```
# TODO
[] update docs
[TESTING] implement the function
# Signature
"""
function isroot(node::MCTSNode)::Bool
    # the root is recognised solely by its reserved key "root"
    return node.nodekey == "root"
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" Search the best action to take for a given state and task
# Arguments
- `a::agent`
one of Yiem's agents
- `initialState`
initial state, stored in the root node of the search tree
- `decisionMaker::Function`
decide what action to take
- `progressValueEstimator::Function`
assess the value of the state
- `reflector::Function`
generate lesson from trajectory and reward
- `isterminal::Function`
determine whether a given state is a terminal state
- `n::Integer`
how many times action will be sampled from decisionMaker
- `maxDepth::Integer`
maximum number of steps of one simulation
- `maxIterations::Integer`
how many search iterations to run
- `w::Float64`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
Value 2.0 makes MCTS aggressively search the tree
# Return
- `plan::Vector{Dict}`
best plan
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[PENDING] return best plan
# Signature
"""
function runMCTS(
    a::T1,
    initialState,
    decisionMaker::Function,
    progressValueEstimator::Function,
    reflector::Function,
    isterminal::Function,
    n::Integer,
    maxDepth::Integer,
    maxIterations::Integer,
    w::Float64) where {T1<:agent}
    # Run `maxIterations` rounds of selection, expansion, simulation and backpropagation from
    # a fresh root holding `initialState`, then return the state of the root's child with the
    # highest `statevalue`.
    #
    # Arguments
    #   a                      - agent forwarded to the callbacks
    #   initialState           - state stored in the root node
    #   decisionMaker          - proposes the next thought/action for a state
    #   progressValueEstimator - scores a state
    #   reflector              - accepted for interface compatibility; not used here yet
    #   isterminal             - reports (isterminalstate, reward) for a state
    #   n                      - number of actions sampled per expansion
    #   maxDepth               - maximum number of steps of one simulation
    #   maxIterations          - number of search iterations
    #   w                      - exploration weight of the UCT formula
    # Throws an error when no child of the root was created during the search.
    root = MCTSNode("root", initialState, 0, "N/A", 0, 0, false, nothing,
        Dict{String, MCTSNode}())
    simDepth = Int(maxDepth) # `simulate` expects an `Int`

    for nth in 1:maxIterations
        # selection: follow the best UCT child until a leaf is reached
        node = root
        while !isleaf(node)
            node = UCTselect(node, w)
        end

        # a terminal leaf cannot be expanded any further; only feed its reward back
        if node.isterminal
            backpropagate(node, node.reward)
            continue
        end

        expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
        leafNode = UCTselect(node, w)
        # the expansion produced no child (e.g. n == 0): nothing to simulate
        leafNode === nothing && continue

        simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
            isterminal, simDepth, n=n)
        backpropagate(leafNode, simTrajectoryReward)
    end

    isempty(root.children) &&
        error("runMCTS: the search ended without creating any child of the root")

    # pick the root child with the highest state value
    candidates = collect(values(root.children))
    _, bestIndex = findmax([child.statevalue for child in candidates])
    best_child_state = candidates[bestIndex].state

    return best_child_state
end
end # module mcts