495 lines
8.3 KiB
Julia
495 lines
8.3 KiB
Julia
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
|
|
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
|
|
and functions for the MCTS algorithm:
|
|
"""
|
|
|
|
module mcts
|
|
|
|
export MCTSNode, runMCTS
|
|
|
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
|
using GeneralUtils
|
|
using ..type
|
|
|
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
|
|
|
""" a node for MCTS search tree
|
|
|
|
Arguments\n
|
|
-----
|
|
state::T
|
|
a state of a game. Can be a Dict or something else.
|
|
For example:
|
|
state = Dict(
|
|
:info=> Dict(), # keyword info
|
|
:thoughtHistory=> Dict(
|
|
:question=> _,
|
|
:thought_1=> _,
|
|
:action_1=> _,
|
|
:observation_1=> _,
|
|
:thought_2=> _,
|
|
...
|
|
)
|
|
)
|
|
visits::Integer
|
|
number of time the game visits this state
|
|
stateValue::Float64
|
|
state value
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[DONE] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
struct MCTSNode{T<:AbstractDict}
|
|
state::T
|
|
visits::Integer
|
|
stateValue::AbstractFloat
|
|
children::Dict{T, MCTSNode}
|
|
end
|
|
|
|
""" Select a node based on UCT score
|
|
|
|
Arguments\n
|
|
-----
|
|
node::MCTSNode
|
|
mcts node
|
|
w::Float64
|
|
exploration weight
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function select(node::MCTSNode, w::Float64)
|
|
max_uct = -Inf
|
|
selectedNode = nothing
|
|
|
|
for (childState, childNode) in node.children
|
|
uctValue = childNode.stateValue +
|
|
w * sqrt(log(node.visits) / childNode.visits)
|
|
if uctValue > max_uct
|
|
max_uct = uctValue
|
|
selectedNode = childNode
|
|
end
|
|
end
|
|
|
|
return selectedNode
|
|
end
|
|
|
|
""" Expand selected node
|
|
|
|
Arguments\n
|
|
-----
|
|
node::MCTSNode
|
|
MCTS node
|
|
state::T
|
|
a state of a game. Can be a Dict or something else.
|
|
decisionMaker::Function
|
|
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[x] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
|
|
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
|
|
|
# sampling action from decisionMaker
|
|
for sample in 1:n
|
|
thought = decisionMaker(a, state)
|
|
error("--> expand")
|
|
newState = transition(node.state, action) #[] Implement your transition function
|
|
if newState ∉ keys(node.children)
|
|
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
|
end
|
|
end
|
|
end
|
|
|
|
"""
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[] implement the function
|
|
[] reward only comes at terminal state
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
|
total_reward = 0.0
|
|
for _ in 1:max_depth
|
|
#[] Implement your action selection function based on highest stateValue
|
|
action = select_action(state) # current state
|
|
state, reward = transition(state, action) # Implement transition function to a new state
|
|
|
|
#[] check for the terminal state
|
|
|
|
total_reward += reward
|
|
end
|
|
return total_reward
|
|
end
|
|
|
|
"""
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function backpropagate(node::MCTSNode, reward::Float64)
|
|
node.visits += 1
|
|
|
|
# [] there is no total_reward in the paper, buy they use stateValue
|
|
node.total_reward += reward
|
|
if !isempty(node.children)
|
|
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
|
backpropagate(node.children[best_child], -reward)
|
|
end
|
|
end
|
|
|
|
"""
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function transition(state, action)
|
|
|
|
end
|
|
|
|
""" Check whether a node is a leaf node of a tree
|
|
|
|
Arguments\n
|
|
-----
|
|
node::MCTSNode
|
|
node of a tree
|
|
|
|
Return\n
|
|
-----
|
|
result::Bool
|
|
true if the node is a leaf node of a tree otherwise false
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia> using
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
|
|
|
|
|
|
|
"""
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[WORKING] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function iterativeprompting(a::T, prompt::String, verification::Function) where {T<:agent}
|
|
|
|
msgMeta = GeneralUtils.generate_msgMeta(
|
|
a.config[:thirdPartyService][:text2textinstruct],
|
|
senderName= "iterativeprompting",
|
|
senderId= a.id,
|
|
receiverName= "text2textinstruct",
|
|
)
|
|
|
|
outgoingMsg = Dict(
|
|
:msgMeta,
|
|
:payload=> Dict(
|
|
:text=> prompt,
|
|
)
|
|
)
|
|
|
|
# iteration loop
|
|
while true
|
|
# send prompt to LLM
|
|
response = GeneralUtils.sendReceiveMqttMsg()
|
|
|
|
# check for correctness and get feedback
|
|
|
|
if correct
|
|
|
|
break
|
|
else
|
|
# get LLM critique
|
|
|
|
# add critique to prompt
|
|
end
|
|
|
|
|
|
|
|
end
|
|
return
|
|
end
|
|
|
|
|
|
""" Check JSON format correctness, provide
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[WORKING] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function syntaxcheck_json(jsonstring::String)::NamedTuple
|
|
success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring)
|
|
if !success # gives feedback
|
|
|
|
|
|
|
|
else
|
|
|
|
|
|
end
|
|
|
|
|
|
return (success, result, critique)
|
|
end
|
|
|
|
|
|
"""
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[] implement the function
|
|
[] implement RAG to pull similar experience
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function isTerminal()
|
|
|
|
end
|
|
|
|
"""
|
|
|
|
Arguments\n
|
|
-----
|
|
|
|
Return\n
|
|
-----
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
[] implement the function
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function executeLLMFunction()
|
|
|
|
end
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------------ #
|
|
# Create a complete example using the defined MCTS functions #
|
|
# ------------------------------------------------------------------------------------------------ #
|
|
""" Search for best action
|
|
|
|
Arguments\n
|
|
-----
|
|
initial state
|
|
initial state
|
|
decisionMaker::Function
|
|
decide what action to take
|
|
stateValueEstimator::Function
|
|
assess the value of the state
|
|
reflector::Function
|
|
generate lesson from trajectory and reward
|
|
n::Integer
|
|
how many times action will be sampled from decisionMaker
|
|
w::Float64
|
|
exploration weight
|
|
|
|
Return\n
|
|
-----
|
|
plan::Vector{Dict}
|
|
best plan
|
|
|
|
Example\n
|
|
-----
|
|
```jldoctest
|
|
julia>
|
|
```
|
|
|
|
TODO\n
|
|
-----
|
|
[] update docstring
|
|
|
|
Signature\n
|
|
-----
|
|
"""
|
|
function runMCTS(a::T, initialState, decisionMaker::Function, stateValueEstimator::Function,
|
|
reflector::Function, n::Integer, maxDepth::Integer,
|
|
maxIterations::Integer, w::Float64) where {T<:agent}
|
|
statetype = typeof(initialState)
|
|
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
|
|
|
|
for _ in 1:maxIterations
|
|
node = root
|
|
while !isLeaf(node)
|
|
node = select(node, w)
|
|
end
|
|
|
|
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
|
|
error("---> runMCTS")
|
|
leaf_node = node.children[node.state] # mark leaf node
|
|
reward = simulate(leaf_node.state, maxDepth)
|
|
backpropagate(leaf_node, reward)
|
|
end
|
|
|
|
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
|
return best_child_state
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end |