# Repository web-view metadata (preserved from the original paste):
# Files: YiemAgent/src/mcts.jl
# narawat lamaiin aacae344c2 update
# 2024-04-26 21:36:32 +07:00
# 495 lines, 8.3 KiB, Julia

""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
"""
module mcts
export MCTSNode, runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree
Arguments\n
-----
state::T
a state of a game. Can be a Dict or something else.
For example:
state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
visits::Integer
number of time the game visits this state
stateValue::Float64
state value
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[DONE] implement the function
Signature\n
-----
"""
struct MCTSNode{T<:AbstractDict}
state::T
visits::Integer
stateValue::AbstractFloat
children::Dict{T, MCTSNode}
end
""" Select a node based on UCT score
Arguments\n
-----
node::MCTSNode
mcts node
w::Float64
exploration weight
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
function select(node::MCTSNode, w::Float64)
max_uct = -Inf
selectedNode = nothing
for (childState, childNode) in node.children
uctValue = childNode.stateValue +
w * sqrt(log(node.visits) / childNode.visits)
if uctValue > max_uct
max_uct = uctValue
selectedNode = childNode
end
end
return selectedNode
end
""" Expand selected node
Arguments\n
-----
node::MCTSNode
MCTS node
state::T
a state of a game. Can be a Dict or something else.
decisionMaker::Function
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[x] implement the function
Signature\n
-----
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker
for sample in 1:n
thought = decisionMaker(a, state)
error("--> expand")
newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
[] reward only comes at terminal state
Signature\n
-----
"""
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
total_reward = 0.0
for _ in 1:max_depth
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
#[] check for the terminal state
total_reward += reward
end
return total_reward
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1
# [] there is no total_reward in the paper, buy they use stateValue
node.total_reward += reward
if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward)
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function transition(state, action)
end
""" Check whether a node is a leaf node of a tree
Arguments\n
-----
node::MCTSNode
node of a tree
Return\n
-----
result::Bool
true if the node is a leaf node of a tree otherwise false
Example\n
-----
```jldoctest
julia> using
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function iterativeprompting(a::T, prompt::String, verification::Function) where {T<:agent}
msgMeta = GeneralUtils.generate_msgMeta(
a.config[:thirdPartyService][:text2textinstruct],
senderName= "iterativeprompting",
senderId= a.id,
receiverName= "text2textinstruct",
)
outgoingMsg = Dict(
:msgMeta,
:payload=> Dict(
:text=> prompt,
)
)
# iteration loop
while true
# send prompt to LLM
response = GeneralUtils.sendReceiveMqttMsg()
# check for correctness and get feedback
if correct
break
else
# get LLM critique
# add critique to prompt
end
end
return
end
""" Check JSON format correctness, provide
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function syntaxcheck_json(jsonstring::String)::NamedTuple
success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring)
if !success # gives feedback
else
end
return (success, result, critique)
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
[] implement RAG to pull similar experience
Signature\n
-----
"""
function isTerminal()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function executeLLMFunction()
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" Search for best action
Arguments\n
-----
initial state
initial state
decisionMaker::Function
decide what action to take
stateValueEstimator::Function
assess the value of the state
reflector::Function
generate lesson from trajectory and reward
n::Integer
how many times action will be sampled from decisionMaker
w::Float64
exploration weight
Return\n
-----
plan::Vector{Dict}
best plan
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function runMCTS(a::T, initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, n::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64) where {T<:agent}
statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
for _ in 1:maxIterations
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
error("---> runMCTS")
leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end
end