# Source file: YiemAgent/src/mcts.jl
# Commit: e558107284 "update" — narawat lamaiin, 2024-04-30 17:39:45 +07:00
# (405 lines, 9.2 KiB, Julia)
# NOTE(review): the lines above were a pasted web-UI file listing, which is not
# valid Julia; converted to a comment so the file parses.
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
"""
module mcts
export MCTSNode, runMCTS, isleaf
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
using ..type, ..llmfunction
# ---------------------------------------------- 100 --------------------------------------------- #
"""
    MCTSNode{T<:AbstractDict}

A node of the MCTS search tree.

Declared `mutable` because `backpropagate` updates `visits` and `stateValue`
in place; the previous immutable `struct` made those assignments throw.
Fields use concrete types (`Int`, `Float64`) instead of the abstract
`Integer`/`AbstractFloat` for type stability.

# Fields
- `state::T`
    a state of a game, e.g. a `Dict` with keyword info and a thought history.
- `visits::Int`
    number of times the search has visited this state.
- `stateValue::Float64`
    estimated value of the state (running mean of backpropagated rewards).
- `children::Dict{T, MCTSNode}`
    child nodes, keyed by the child's state.

# Example
```jldoctest
julia> state = Dict(
    :info=> Dict(), # keyword info
    :thoughtHistory=> Dict(
        :question=> _,
        :thought_1=> _,
        :action_1=> _,
        :observation_1=> _,
        :thought_2=> _,
        ...
    )
)
julia> root = MCTSNode(state, 0, 0.0, Dict{typeof(state), MCTSNode}())
```
# Signature
"""
mutable struct MCTSNode{T<:AbstractDict}
    state::T
    visits::Int
    stateValue::Float64
    children::Dict{T, MCTSNode}
end
""" Select a child of `node` by the UCT (Upper Confidence bound for Trees) score.

UCT = stateValue + w * sqrt(log(parentVisits) / childVisits).

Unvisited children (`visits == 0`) are given infinite priority, which is the
standard UCT convention and avoids the `0/0 = NaN` the previous code produced
when the parent had exactly one visit (a NaN score meant the unvisited child
could never be selected). The `log` argument is clamped to at least 1 so an
unvisited parent does not raise `DomainError` via `sqrt` of a negative value.

# Arguments
- `node::MCTSNode`
    parent node whose children are scored
- `w::Float64`
    exploration weight; larger values favor less-visited children

# Return
- the child `MCTSNode` with the maximal UCT score, or `nothing` if `node`
  has no children

# Signature
"""
function select(node::MCTSNode, w::Float64)
    max_uct = -Inf
    selectedNode = nothing
    for childNode in values(node.children)
        uctValue = if childNode.visits == 0
            # Unvisited child: explore it first.
            Inf
        else
            childNode.stateValue +
                w * sqrt(log(max(node.visits, 1)) / childNode.visits)
        end
        if uctValue > max_uct
            max_uct = uctValue
            selectedNode = childNode
        end
    end
    return selectedNode
end
""" Expand `node` by sampling `n` actions from `decisionMaker` and inserting
each resulting state as a child (if not already present).

# Arguments
- `a::T1`
    one of YiemAgent's agents
- `node::MCTSNode`
    the node to expand; `node.children` is mutated in place
- `state::T2`
    current game state (passed to `decisionMaker`; its type also keys the
    new child's `children` dict)
- `decisionMaker::Function`
    `(a, state) -> JSON string` containing Thought/Action/Observation entries,
    e.g. `:Thought_1 => "...", :Action_1 => Dict(:action=>"Chatbox", :input=>"..."), :Observation_1 => ""`
- `stateValueEstimator::Function`
    value estimator — currently unused; TODO wire it into child initialization

# Keywords
- `n::Integer=3`
    number of action samples to draw

# Return
- `nothing`

# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
    n::Integer=3) where {T1<:agent, T2<:AbstractDict}
    # sampling action from decisionMaker
    for _ in 1:n
        thoughtJstr = decisionMaker(a, state)
        thoughtDict = copy(JSON3.read(thoughtJstr))
        newstate = MCTStransition(a, node.state, thoughtDict)
        # The previous code read `if newstate keys(node.children)` — a syntax
        # error; the intent (insert only unseen states) is expressed here.
        # NOTE(review): children are keyed by the full state Dict; a prior BUG
        # comment suggested keying by a uuid instead — TODO confirm design.
        if !haskey(node.children, newstate)
            statetype = typeof(state)
            node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
        end
    end
    return nothing
end
""" Roll out a simulated trajectory of at most `max_depth` steps from `state`,
accumulating the per-step rewards.

Depends on `select_action` and `transition`, which are not yet implemented
(see TODO below); rewards are currently summed at every step rather than
awarded only at a terminal state.

# Arguments
- `state::T`
    starting game state
- `max_depth::Int`
    maximum number of rollout steps

# Return
- accumulated reward over the rollout (`Float64`)

# TODO
- [] update docstring
- [] implement the function
- [] reward only comes at terminal state
# Signature
"""
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
    accumulated = 0.0
    depth = 0
    while depth < max_depth
        # Pick an action for the current state (highest stateValue — TBD).
        chosen = select_action(state)
        state, stepReward = transition(state, chosen)
        # TODO: stop early when a terminal state is reached.
        accumulated += stepReward
        depth += 1
    end
    return accumulated
end
""" Update `node` with `reward` and recurse down the best-valued child path.

Per the LATS-style TODO in the original ("there is no total_reward in the
paper, they use stateValue"), `stateValue` is maintained as the running mean
of backpropagated rewards. The previous body wrote to a nonexistent
`total_reward` field and indexed the state-keyed `children` Dict with the
integer returned by `argmax` over a comprehension — both would throw.

NOTE(review): this recurses *down* into children, not up to parents (nodes
carry no parent pointer); confirm this matches the intended backup rule.
The `-reward` sign flip from the original (adversarial convention) is kept.

# Arguments
- `node::MCTSNode`
    node to update
- `reward::Float64`
    reward from the simulation rollout

# Return
- `nothing`

# Signature
"""
function backpropagate(node::MCTSNode, reward::Float64)
    # Running mean: stateValue averages all rewards seen at this node.
    node.stateValue = (node.stateValue * node.visits + reward) / (node.visits + 1)
    node.visits += 1
    if !isempty(node.children)
        # argmax(f, itr) returns the maximizing *element*, which is safe for
        # Dict values; unvisited children score -Inf to avoid 0/0.
        best_child = argmax(c -> c.visits == 0 ? -Inf : c.stateValue, values(node.children))
        backpropagate(best_child, -reward)
    end
    return nothing
end
""" Get a new state by applying the latest thought/action to `state`.

Extracts the highest-indexed Thought/Action from `thoughtDict`, executes the
action via the matching LLM function, and records Thought, Action, and the
resulting Observation into a deep copy of `state`'s `:thoughtHistory`.

# Arguments
- `a::T1`
    one of YiemAgent's agents
- `state::T2`
    current game state
- `thoughtDict::T3`
    contains Thought, Action, Observation entries
# Return
- `newstate::AbstractDict`
    next game state
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
    :thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."),
    :storeinfo => Dict(),
    :customerinfo => Dict()
)
julia> thoughtDict = Dict(
    :Question=> "I want to buy a bottle of wine.",
    :Thought_1=> "The customer wants to buy a bottle of wine.",
    :Action_1=> Dict{Symbol, Any}(
        :name=>"Chatbox",
        :input=>"What occasion are you buying the wine for?",
    ),
    :Observation_1 => ""
)
```
# TODO
- [] update docstring
- [PENDING] add other actions
# Signature
"""
function MCTStransition(a::T1, state::T2,
    thoughtDict::T3)::AbstractDict where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
    latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
    latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
    _action = thoughtDict[latestActionKey]
    # Doc examples produce :name=>"Chatbox" (capitalized) while the previous
    # comparison was against "chatbox" and so never matched as documented;
    # normalize case before dispatching. TODO confirm against decisionMaker output.
    actionname = lowercase(_action[:name])
    actioninput = _action[:input]
    # map action and input() to llm function
    response =
        if actionname == "chatbox"
            virtualWineCustomerChatbox(a, actioninput) # virtual customer
        elseif actionname == "winestock"
            nothing # [PENDING] not implemented yet
        elseif actionname == "finish"
            nothing # [PENDING] not implemented yet
        else
            nothing # unknown action → empty observation
        end
    # add Thought, action, observation to thoughtHistory
    newstate = deepcopy(state)
    newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey]
    latestObservationKey = Symbol("Observation_$(latestActionIndice)")
    newstate[:thoughtHistory][latestObservationKey] = response
    return newstate
end
""" Determine whether a node is a leaf node of a search tree.

A node is a leaf exactly when it has no children.

# Arguments
- `node::MCTSNode`
    a search tree node
# Return
- `result::Bool`
    true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
    :customerinfo=> Dict{Symbol, Any}(),
    :storeinfo=> Dict{Symbol, Any}(),
    :thoughtHistory=> OrderedDict{Symbol, Any}(
        :Question=> "How are you?",
    )
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# Signature
"""
function isleaf(node::MCTSNode)::Bool
    return length(node.children) == 0
end
"""
    executeLLMFunction()

Placeholder for mapping an agent action onto an LLM function call.
Not yet implemented; returns `nothing`.

# TODO
- [] update docstring
- [] implement the function
# Signature
"""
executeLLMFunction() = nothing
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" Search the best next state for a given initial state and task.

Runs `maxIterations` of the MCTS loop: select a leaf by UCT, expand it with
`n` sampled actions, simulate a rollout of at most `maxDepth` steps, and
backpropagate the reward. Returns the state of the root child with the best
mean value.

Fixes over the previous version: the leftover `error("---> runMCTS")` made
the function unconditionally throw before returning; the leaf was looked up
as `node.children[node.state]` although `children` is keyed by the *child*
states (guaranteed `KeyError`); the final selection read a nonexistent
`total_reward` field and returned `argmax`'s integer index instead of a state.

# Arguments
- `a::agent`
    one of Yiem's agents
- `initialState`
    initial state
- `decisionMaker::Function`
    decide what action to take
- `stateValueEstimator::Function`
    assess the value of the state
- `reflector::Function`
    generate lesson from trajectory and reward (currently unused — TODO)
- `isterminal::Function`
    determine whether a given state is a terminal state (currently unused — TODO)
- `n::Integer`
    how many times action will be sampled from decisionMaker
- `maxDepth::Integer`
    maximum rollout depth per simulation
- `maxIterations::Integer`
    number of MCTS iterations
- `w::Float64`
    exploration weight
# Return
- the state (an `AbstractDict`) of the best root child, or `nothing` if no
  child was ever created
# Signature
"""
function runMCTS(
    a::T1,
    initialState,
    decisionMaker::Function,
    stateValueEstimator::Function,
    reflector::Function,
    isterminal::Function,
    n::Integer,
    maxDepth::Integer,
    maxIterations::Integer,
    w::Float64) where {T1<:agent}
    statetype = typeof(initialState)
    root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
    for _ in 1:maxIterations
        # Selection: descend by UCT until a leaf is reached.
        node = root
        while !isleaf(node)
            node = select(node, w)
        end
        # Expansion: children are keyed by the *new* states created here.
        expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
        isempty(node.children) && continue  # expansion produced nothing new
        leaf_node = first(values(node.children))
        # Simulation + backpropagation.
        reward = simulate(leaf_node.state, maxDepth)
        backpropagate(leaf_node, reward)
    end
    isempty(root.children) && return nothing
    # argmax(f, itr) returns the maximizing element (not an index).
    best_child = argmax(c -> c.visits == 0 ? -Inf : c.stateValue, values(root.children))
    return best_child.state
end
end