"""
Monte Carlo Tree Search (MCTS) with the UCT (Upper Confidence Bound for Trees)
selection rule: defines the search-tree node type and the
select / expand / simulate / backpropagate steps.
"""
|
|
|
|
module mcts
|
|
|
|
export MCTSNode, runMCTS, isleaf
|
|
|
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
|
using GeneralUtils
|
|
using ..type, ..llmfunction
|
|
|
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
|
|
|
""" a node for MCTS search tree
|
|
|
|
# Arguments
|
|
- `state::T`
|
|
a state of a game. Can be a Dict or something else.
|
|
- `visits::Integer `
|
|
number of time the game visits this state
|
|
- `stateValue::Float64`
|
|
state value
|
|
- `children::Dict{T, MCTSNode}`
|
|
children node
|
|
|
|
# Return
|
|
- `nothing`
|
|
# Example
|
|
```jldoctest
|
|
julia> state = Dict(
|
|
:info=> Dict(), # keyword info
|
|
:thoughtHistory=> Dict(
|
|
:question=> _,
|
|
:thought_1=> _,
|
|
:action_1=> _,
|
|
:observation_1=> _,
|
|
:thought_2=> _,
|
|
...
|
|
)
|
|
)
|
|
```
|
|
|
|
# TODO
|
|
[] update docstring
|
|
|
|
# Signature
|
|
"""
|
|
"""
    MCTSNode{T<:AbstractDict}

A node of the MCTS search tree.

# Fields
- `statekey::String`: unique key identifying this node's state.
- `state::T`: a game state (typically a `Dict`).
- `visits::Int`: number of times the search has visited this node.
- `progressValue::Float64`: estimated value (progress score) of the state.
- `children::Dict{String, MCTSNode}`: child nodes keyed by their state key.

NOTE(review): the struct must be `mutable` because `backpropagate` updates
`visits`/`progressValue` in place; the previous immutable definition made
those updates throw at runtime. Field types are concrete (`Int`, `Float64`)
instead of abstract (`Integer`, `Number`) so instances are type-stable.
"""
mutable struct MCTSNode{T<:AbstractDict}
    statekey::String
    state::T
    visits::Int
    progressValue::Float64
    children::Dict{String, MCTSNode}
end
|
|
|
|
""" Select a node based on UCT score
|
|
|
|
# Arguments
|
|
- `node::MCTSNode`
|
|
mcts node
|
|
- `w::Float64`
|
|
exploration weight
|
|
# Return
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
# TODO
|
|
[] update docstring
|
|
[TESTING] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
|
|
|
# Signature
|
|
"""
|
|
"""
    select(node::MCTSNode, w::Float64) -> Union{MCTSNode, Nothing}

Select the child of `node` with the highest UCT score

    UCT = child.progressValue + w * sqrt(log(parent.visits) / child.visits)

Unvisited children receive an infinite score so they are explored first.
Returns `nothing` when `node` has no children.

# Arguments
- `node::MCTSNode`: parent node whose children are scored.
- `w::Float64`: exploration weight.
"""
function select(node::MCTSNode, w::Float64)
    max_uct = -Inf
    selectedNode = nothing

    for (_, childNode) in node.children
        # An unvisited child would divide by zero below; give it top priority.
        uctValue = if childNode.visits == 0
            Inf
        else
            # Bug fix: the original read `childNode.stateValue`, a field that
            # does not exist on `MCTSNode` — the value field is `progressValue`.
            # `max(node.visits, 1)` guards against `log(0) == -Inf`.
            childNode.progressValue +
                w * sqrt(log(max(node.visits, 1)) / childNode.visits)
        end
        if uctValue > max_uct
            max_uct = uctValue
            selectedNode = childNode
        end
    end

    return selectedNode
end
|
|
|
|
""" Expand selected node
|
|
|
|
# Arguments
|
|
- `a::T1`
|
|
One of YiemAgent's agent
|
|
- `node::MCTSNode`
|
|
MCTS node
|
|
- `state::T2`
|
|
a state of a game. Can be a Dict or something else.
|
|
- `decisionMaker::Function`
|
|
a function that output Thought and Action
|
|
- `progressValueEstimator::Function`
|
|
a function that output trajectory progress score
|
|
|
|
# Return
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
# Signature
|
|
"""
|
|
"""
    expand(a, node::MCTSNode, state, decisionMaker, progressValueEstimator; n=3)

Expand `node` by sampling `n` candidate actions from `decisionMaker`, applying
the game transition, scoring each resulting state with
`progressValueEstimator`, and attaching any previously unseen states as
children of `node`.

# Arguments
- `a::T1`: one of YiemAgent's agents.
- `node::MCTSNode`: node to expand.
- `state::T2`: game state passed to `decisionMaker`.
- `decisionMaker::Function`: produces a Thought/Action dict for a state.
- `progressValueEstimator::Function`: returns `(_, score)` for a trajectory.

# Keywords
- `n::Integer=3`: number of action samples drawn from `decisionMaker`.
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
    progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}

    # Sample `n` actions from the decision maker.
    for _ in 1:n
        thoughtDict = decisionMaker(a, state)
        # Leftover `@show` debug prints replaced with `@debug` so normal runs stay quiet.
        @debug "expand sample" state thoughtDict

        # NOTE(review): the transition uses `node.state` while the decision
        # used `state` — confirm the two are intended to be the same state.
        newStatekey, newstate = MCTStransition(a, node.state, thoughtDict)

        # Score the trajectory progress of the new state.
        _, progressValue = progressValueEstimator(a, newstate)

        # Only attach genuinely new states as children.
        if !haskey(node.children, newStatekey)
            node.children[newStatekey] =
                MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
        end
    end

    return nothing
end
|
|
|
|
"""
|
|
|
|
# Arguments
|
|
|
|
# Return
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
# TODO
|
|
- [] update docstring
|
|
- [WORKING] implement the function
|
|
- [] reward only comes at terminal state
|
|
|
|
# Signature
|
|
"""
|
|
# NOTE(review): unimplemented stub. The leading `error(...)` guard makes every
# line below it unreachable, and `select_action` / `transition` are not defined
# anywhere in this file — keep the guard in place until they exist.
# Intended behavior (per the TODOs above): roll out up to `max_depth` steps
# from `state` and return the accumulated reward.
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
    # Guard: fail loudly if a caller reaches this unfinished rollout.
    error("--> simulate")
    total_reward = 0.0  # accumulated rollout reward (dead code while the guard is present)
    for _ in 1:max_depth
        #[] Implement your action selection function based on highest stateValue
        action = select_action(state) # current state
        state, reward = transition(state, action) # Implement transition function to a new state

        #[] check for the terminal state

        total_reward += reward
    end
    return total_reward
end
|
|
|
|
"""
|
|
|
|
# Arguments
|
|
|
|
# Return
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
# TODO
|
|
- [] update docstring
|
|
- [] implement the function
|
|
|
|
# Signature
|
|
"""
|
|
"""
    backpropagate(node::MCTSNode, reward::Float64)

Propagate a simulation `reward` down the best-child chain starting at `node`:
increment each node's visit count, fold the reward into its running value
estimate, and recurse into the highest-valued child with the reward sign
flipped (the adversarial convention used by the original code).

NOTE(review): requires `MCTSNode` to be a `mutable struct`. The LATS-style
formulation tracks a state value rather than a separate total reward, so the
update below maintains `progressValue` as a running average.
"""
function backpropagate(node::MCTSNode, reward::Float64)
    node.visits += 1

    # Running-average value update. Bug fix: the original wrote to
    # `node.total_reward`, a field that does not exist on `MCTSNode`.
    node.progressValue += (reward - node.progressValue) / node.visits

    if !isempty(node.children)
        # Bug fix: the original `argmax` over a comprehension returned an
        # *array index* and then used it to index a `Dict{String, MCTSNode}`,
        # which always throws. Pick the best child key instead.
        bestkey = argmax(k -> begin
                child = node.children[k]
                child.visits == 0 ? -Inf : child.progressValue
            end, collect(keys(node.children)))
        backpropagate(node.children[bestkey], -reward)
    end

    return nothing
end
|
|
|
|
""" Get a new state
|
|
|
|
# Arguments
|
|
- `a::T1`
|
|
one of YiemAgent's agent
|
|
- `state::T2`
|
|
current game state
|
|
- `thoughtDict::T3`
|
|
contain Thought, Action, Observation
|
|
|
|
# Return
|
|
- (newStatekey, )
|
|
- `newStatekey::String`
|
|
key for newstate
|
|
- `newstate::Dict{Symbol, Any}`
|
|
next game state
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
|
:thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."),
|
|
:storeinfo => Dict(),
|
|
:customerinfo => Dict()
|
|
)
|
|
julia> thoughtDict = Dict(
|
|
:Question=> "I want to buy a bottle of wine.",
|
|
:Thought_1=> "The customer wants to buy a bottle of wine.",
|
|
:Action_1=> Dict{Symbol, Any}(
|
|
:name=>"Chatbox",
|
|
:input=>"What occasion are you buying the wine for?",
|
|
),
|
|
:Observation_1 => ""
|
|
)
|
|
```
|
|
|
|
# TODO
|
|
- [] update docstring
|
|
- [PENDING] add other actions
|
|
- [] add embedding of newstate and store in newstate[:embedding]
|
|
|
|
# Signature
|
|
"""
|
|
"""
    MCTStransition(a, state, thoughtDict) -> (newStatekey, newstate)

Apply the latest Thought/Action in `thoughtDict` to `state`: execute the
action via the matching LLM function, append the Thought, Action, and
resulting Observation to the state's `:thoughtHistory`, and return a fresh
UUID key together with the new state.

# Arguments
- `a::T1`: one of YiemAgent's agents.
- `state::T2`: current game state.
- `thoughtDict::T3`: contains Thought, Action, Observation entries.

# Return
- `newStatekey::String`: key for the new state.
- `newstate::Dict{Symbol, Any}`: next game state.

# TODO
- [PENDING] implement the "winestock" and "finish" actions
- [] add embedding of newstate and store in newstate[:embedding]
"""
function MCTStransition(a::T1, state::T2,
    thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}

    latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
    latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
    _action = thoughtDict[latestActionKey]
    # Bug fix: the docstring examples use `:name => "Chatbox"`, but the match
    # below was case-sensitive against "chatbox" — normalize before comparing.
    actionname = lowercase(String(_action[:name]))
    actioninput = _action[:input]

    # Map the action name to the corresponding LLM function.
    response =
        if actionname == "chatbox"
            virtualWineCustomerChatbox(a, actioninput) # virtual customer
        elseif actionname == "winestock"
            nothing # [PENDING] not implemented yet
        elseif actionname == "finish"
            nothing # [PENDING] not implemented yet
        else
            nothing # unknown action: no observation
        end

    # Append Thought, Action, and Observation to the thought history.
    newstate = deepcopy(state)
    newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey]
    latestObservationKey = Symbol("Observation_$(latestActionIndice)")
    newstate[:thoughtHistory][latestObservationKey] = response

    newStatekey = GeneralUtils.uuid4snakecase()

    return newStatekey, newstate
end
|
|
|
|
|
|
""" Determine whether a node is a leaf node of a search tree.
|
|
|
|
# Arguments
|
|
- `node::MCTSNode`
|
|
a search tree node
|
|
# Return
|
|
- `result::Bool`
|
|
true if it is a leaf node, false otherwise.
|
|
# Example
|
|
```jldoctest
|
|
julia> using Revise
|
|
julia> using YiemAgent, DataStructures
|
|
julia> initialState = Dict{Symbol, Any}(
|
|
:customerinfo=> Dict{Symbol, Any}(),
|
|
:storeinfo=> Dict{Symbol, Any}(),
|
|
|
|
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
|
:Question=> "How are you?",
|
|
)
|
|
)
|
|
julia> statetype = typeof(initialState)
|
|
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
|
|
julia> YiemAgent.isleaf(root)
|
|
true
|
|
```
|
|
|
|
# Signature
|
|
"""
|
|
"""
    isleaf(node::MCTSNode) -> Bool

Return `true` when `node` has no children, i.e. it is a leaf of the search
tree; `false` otherwise.
"""
function isleaf(node::MCTSNode)::Bool
    return isempty(node.children)
end
|
|
|
|
|
|
"""
|
|
|
|
# Arguments
|
|
|
|
# Return
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
# TODO
|
|
- [] update docstring
|
|
- [] implement the function
|
|
|
|
# Signature
|
|
"""
|
|
"""
    executeLLMFunction()

Placeholder for dispatching an LLM-backed action. Not implemented yet;
currently returns `nothing`.
"""
function executeLLMFunction()
    return nothing
end
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------------ #
|
|
# Create a complete example using the defined MCTS functions #
|
|
# ------------------------------------------------------------------------------------------------ #
|
|
""" Search the best action to take for a given state and task
|
|
|
|
# Arguments
|
|
- `a::agent`
|
|
one of Yiem's agents
|
|
- `initial state`
|
|
initial state
|
|
- `decisionMaker::Function`
|
|
decide what action to take
|
|
- `progressValueEstimator::Function`
|
|
assess the value of the state
|
|
- `reflector::Function`
|
|
generate lesson from trajectory and reward
|
|
- `isterminal::Function`
|
|
determine whether a given state is a terminal state
|
|
- `n::Integer`
|
|
how many times action will be sampled from decisionMaker
|
|
- `w::Float64`
|
|
exploration weight
|
|
|
|
# Return
|
|
- `plan::Vector{Dict}`
|
|
best plan
|
|
|
|
# Example
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
# TODO
|
|
[] update docstring
|
|
|
|
# Signature
|
|
"""
|
|
"""
    runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
            isterminal, n, maxDepth, maxIterations, w)

Run MCTS from `initialState` for `maxIterations` iterations (select → expand →
simulate → backpropagate) and return the state of the best root child.

# Arguments
- `a::T1`: one of Yiem's agents.
- `initialState`: initial game state.
- `decisionMaker::Function`: decides which action to take.
- `progressValueEstimator::Function`: assesses the value of a state.
- `reflector::Function`: generates a lesson from trajectory and reward
  (currently unused — kept for interface compatibility).
- `isterminal::Function`: detects terminal states (currently unused).
- `n::Integer`: number of action samples per expansion.
- `maxDepth::Integer`: maximum simulation depth.
- `maxIterations::Integer`: number of MCTS iterations.
- `w::Float64`: exploration weight.

# Return
- the best root child's state, or `nothing` when the root was never expanded.
"""
function runMCTS(
    a::T1,
    initialState,
    decisionMaker::Function,
    progressValueEstimator::Function,
    reflector::Function,
    isterminal::Function,
    n::Integer,
    maxDepth::Integer,
    maxIterations::Integer,
    w::Float64) where {T1<:agent}

    root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())

    for _ in 1:maxIterations
        # Selection: descend by UCT until reaching a leaf.
        node = root
        while !isleaf(node)
            node = select(node, w)
        end

        # Expansion.
        expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)

        # Simulation starts at the selected node (per the paper), not at a
        # newly expanded child.
        leaf_node = node
        reward = simulate(leaf_node.state, maxDepth)

        # Backpropagation.
        backpropagate(leaf_node, reward)
    end

    # Bug fixes: the original `argmax` over a comprehension produced an array
    # *index* (not a child), read a nonexistent `total_reward` field, and then
    # always hit a leftover `error("---> runMCTS")` debug guard before
    # returning. Return the state of the highest-valued root child instead.
    isempty(root.children) && return nothing
    bestchild = argmax(c -> c.visits == 0 ? -Inf : c.progressValue,
        collect(values(root.children)))
    return bestchild.state
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module mcts |