"""
Monte Carlo Tree Search (MCTS) with UCT (Upper Confidence Bound for Trees) selection.

The search loop in `runMCTS` follows the LATS scheme:
selection (UCT) → expansion → simulation (greedy rollout) → backpropagation.
"""
module mcts

export MCTSNode, runMCTS, isleaf

using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils
using ..type, ..llmfunction

# ---------------------------------------------- 100 --------------------------------------------- #

"""
A node of the MCTS search tree.

The struct is `mutable` because `backpropagate` updates `visits` and `progressValue`
in place while walking from a node back up to the root.

# Fields
- `nodekey::String` key of this node in its parent's `children` (`"root"` for the root)
- `state::T` game state at this node (an `AbstractDict`, see example)
- `visits::Integer` number of times backpropagation has passed through this node
- `progressValue::Number` value estimate V(s): initialised from `progressValueEstimator`,
  then replaced by the running mean of backed-up rewards once the node is visited
- `reward::Number` reward returned by `isterminal` for this node's state
- `isterminal::Bool` whether `state` is a terminal state
- `parent::Union{MCTSNode, Nothing}` parent node (`nothing` for the root)
- `children::Dict{String, MCTSNode}` child nodes keyed by their `nodekey`

# Example
Illustrative shape of `state` (`_` marks placeholder values):
```julia
julia> state = Dict(
           :info => Dict(),  # keyword info
           :thoughtHistory => Dict(
               :Question => _,
               :Thought_1 => _,
               :Action_1 => _,
               :Observation_1 => _,
               :Thought_2 => _,
               ...
           )
       )
```
"""
mutable struct MCTSNode{T<:AbstractDict}
    nodekey::String
    state::T
    # Numeric fields stay abstract so existing positional construction
    # (e.g. with literal `0`s) keeps working unchanged.
    visits::Integer
    progressValue::Number
    reward::Number
    isterminal::Bool
    parent::Union{MCTSNode, Nothing}
    children::Dict{String, MCTSNode}
end

"""
    UCTselect(node::MCTSNode, w::Real) -> MCTSNode

Select the child of `node` with the highest UCT score

    UCT(child) = V(child) + w * sqrt(log(N(node)) / N(child))

where `V` is `child.progressValue` and `N` is the visit count.

A child with `visits == 0` is returned immediately: its exploration term is infinite,
and evaluating the formula literally would give `NaN` (or a `DomainError` when the
parent is unvisited as well).

# Arguments
- `node::MCTSNode` node whose children are scored; must have at least one child
- `w::Real` exploration weight

# Return
- the selected child `MCTSNode`

# Throws
- `ArgumentError` if `node` has no children
"""
function UCTselect(node::MCTSNode, w::Real)
    isempty(node.children) &&
        throw(ArgumentError("UCTselect: node \"$(node.nodekey)\" has no children"))

    # Guard against log(0) = -Inf if the parent itself has not been visited yet.
    parentVisits = max(node.visits, 1)
    selectedNode = nothing
    maxUCT = -Inf
    for childNode in values(node.children)
        childNode.visits == 0 && return childNode
        uctValue = childNode.progressValue + w * sqrt(log(parentVisits) / childNode.visits)
        if selectedNode === nothing || uctValue > maxUCT
            maxUCT = uctValue
            selectedNode = childNode
        end
    end
    return selectedNode
end

"""
    expand(a, node, decisionMaker, progressValueEstimator, isterminal; n=3)

Expand `node` by sampling `n` decisions from `decisionMaker` and adding one child per
sample to `node.children`.

Each sampled decision is applied with `MCTStransition`. The resulting state is scored
with `progressValueEstimator`, and that score becomes the child's initial `progressValue`.

# Arguments
- `a::T1` one of YiemAgent's agents
- `node::MCTSNode` node to expand (mutated: children are added)
- `decisionMaker::Function` called as `decisionMaker(a, state)`. Returns a thought dict
  with `:Thought` and `:Action`.
- `progressValueEstimator::Function` called as `progressValueEstimator(a, state)`.
  Returns `(rationale, value)`.
- `isterminal::Function` called as `isterminal(state)`. Returns `(isterminal, reward)`.

# Keywords
- `n::Integer=3` number of decisions to sample

# Return
- `nothing`
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
                progressValueEstimator::Function, isterminal::Function;
                n::Integer=3) where {T1<:agent}
    for _ in 1:n
        thoughtDict = decisionMaker(a, node.state)
        newNodeKey, newstate, isterminalstate, reward =
            MCTStransition(a, node.state, thoughtDict, isterminal)

        # Only the numeric score is stored on the node; the rationale is discarded.
        _, progressValue = progressValueEstimator(a, newstate)

        # Keys come from GeneralUtils.uuid4snakecase() (presumably unique); the guard
        # still ensures an existing child is never overwritten.
        if !haskey(node.children, newNodeKey)
            node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
                                                 reward, isterminalstate, node,
                                                 Dict{String, MCTSNode}())
        end
    end
    return nothing
end

"""
    simulate(a, node, decisionMaker, progressValueEstimator, isterminal, max_depth; n=3)

Run a greedy rollout from `node`. At each step it moves to the child with the highest
`progressValue`, expanding leaf nodes on the way. The rollout stops when a terminal
node is reached or after `max_depth` steps.

Nodes created during the rollout are added to the search tree via `expand`.

# Arguments
- `a` one of YiemAgent's agents
- `node::MCTSNode` node the rollout starts from
- `decisionMaker`, `progressValueEstimator`, `isterminal` same as in `expand`
- `max_depth::Integer` maximum number of rollout steps

# Keywords
- `n=3` number of decisions sampled when a leaf is expanded

# Return
- the `reward` of the terminal node reached. If `max_depth` is hit first, the `reward`
  of the last node visited.

# TODO
- [] when the rollout is cut off before a terminal state, consider returning the last
  node's `progressValue` instead of its `reward`
"""
function simulate(a, node::MCTSNode, decisionMaker::Function,
                  progressValueEstimator::Function, isterminal::Function,
                  max_depth::Integer; n=3)
    for _ in 1:max_depth
        node.isterminal && return node.reward
        if isleaf(node)
            expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
        end
        node = selectChildNode(node)
    end
    return node.reward
end

"""
    backpropagate(node::MCTSNode, reward::Real)

Propagate `reward` from `node` up to the root by following `parent` links.

For every node on the path, including `node` itself, the visit count is incremented.
`progressValue` is then replaced by the running mean of backed-up rewards:

    V(s) ← (V(s) * (N(s) - 1) + reward) / N(s)

The reward is propagated unchanged to every ancestor; it is not negated per level as
in two-player MCTS.

# Arguments
- `node::MCTSNode` node at which `reward` was obtained
- `reward::Real` reward to propagate

# Return
- `nothing`
"""
function backpropagate(node::MCTSNode, reward::Real)
    current = node
    while current !== nothing
        current.visits += 1
        # Incremental form of the running mean above.
        current.progressValue += (reward - current.progressValue) / current.visits
        current = current.parent
    end
    return nothing
end

"""
    MCTStransition(a, state, thoughtDict, isterminal) -> (newNodeKey, newstate, isterminalstate, reward)

Apply the action in `thoughtDict` to `state` and return the resulting state.

The action is dispatched on `thoughtDict[:Action][:name]`, compared case-insensitively,
and its response becomes the new Observation. The Thought, Action and Observation are
written into a deep copy of `state` under the next free index
(`:Thought_i`, `:Action_i`, `:Observation_i` in `:thoughtHistory`), so `state` itself
is not modified.

# Arguments
- `a::T1` one of YiemAgent's agents
- `state::T2` current game state; must contain `:thoughtHistory`
- `thoughtDict::T3` decision with keys `:Thought` and `:Action`; `:Action` holds
  `:name` and `:input`
- `isterminal::Function` called as `isterminal(newstate)`. Returns `(isterminal, reward)`.

# Return
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, <:AbstractDict{Symbol}, Bool, <:Number}`

# Example
```julia
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
           :thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."),
           :storeinfo => Dict(),
           :customerinfo => Dict()
       )

julia> thoughtDict = Dict(
           :Thought => "The customer wants to buy a bottle of wine.",
           :Action => Dict{Symbol, Any}(
               :name => "Chatbox",
               :input => "What occasion are you buying the wine for?",
           ),
       )
```

# TODO
- [PENDING] add other actions ("winestock" and "finish" currently yield no observation)
- [] add embedding of newstate and store in newstate[:embedding]
"""
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
)::Tuple{String, <:AbstractDict{Symbol}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
    pprint(thoughtDict)  # debug output of the sampled decision

    actionname = lowercase(string(thoughtDict[:Action][:name]))
    actioninput = thoughtDict[:Action][:input]

    # Map the action to its llm function; the result becomes the Observation.
    # Only "chatbox" is implemented so far; every other action yields `nothing`.
    response = if actionname == "chatbox"
        virtualWineCustomerChatbox(a, actioninput)  # virtual customer
    elseif actionname == "winestock"
        nothing
    elseif actionname == "finish"
        nothing
    else
        nothing
    end

    # NOTE(review): findHighestIndexKey presumably returns `:NA` as the key when no
    # `Thought_i` entry exists yet — confirm in GeneralUtils.
    latestThoughtKey, latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought")
    nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
    thoughtKey = Symbol("Thought_$nextIndice")
    actionKey = Symbol("Action_$nextIndice")
    observationKey = Symbol("Observation_$nextIndice")

    # deepcopy: thoughtHistory is a nested dict shared with the parent node's state.
    newstate = deepcopy(state)
    newstate[:thoughtHistory][thoughtKey] = thoughtDict[:Thought]
    newstate[:thoughtHistory][actionKey] = thoughtDict[:Action]
    newstate[:thoughtHistory][observationKey] = response

    newNodeKey = GeneralUtils.uuid4snakecase()
    isterminalstate, reward = isterminal(newstate)
    return (newNodeKey, newstate, isterminalstate, reward)
end

"""
    isleaf(node::MCTSNode) -> Bool

Determine whether `node` is a leaf of the search tree, i.e. has no children.

# Example
```jldoctest
julia> using YiemAgent, DataStructures

julia> initialState = Dict{Symbol, Any}(
           :customerinfo => Dict{Symbol, Any}(),
           :storeinfo => Dict{Symbol, Any}(),
           :thoughtHistory => OrderedDict{Symbol, Any}(:Question => "How are you?")
       );

julia> root = YiemAgent.MCTSNode("root", initialState, 0, 0, 0, false, nothing,
                                 Dict{String, YiemAgent.MCTSNode}());

julia> YiemAgent.isleaf(root)
true
```
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)

"""
    selectChildNode(node::MCTSNode) -> MCTSNode

Return the child of `node` with the highest `progressValue`. On ties, the first child
encountered while iterating `node.children` wins.

# Throws
- `ArgumentError` if `node` has no children
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    isempty(node.children) &&
        throw(ArgumentError("selectChildNode: node \"$(node.nodekey)\" has no children"))

    bestChild = nothing
    for childNode in values(node.children)
        if bestChild === nothing || childNode.progressValue > bestChild.progressValue
            bestChild = childNode
        end
    end
    return bestChild
end

# ------------------------------------------------------------------------------------------------ #
#                                   Search entry point                                             #
# ------------------------------------------------------------------------------------------------ #

"""
    runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, isterminal,
            n, maxDepth, maxIterations, w)

Search for the best next state from `initialState` using MCTS with UCT selection.

Each iteration performs these steps:
1. Descend from the root with `UCTselect` until a leaf is reached.
2. If that leaf is terminal, back up its reward.
3. Otherwise expand the leaf, run `simulate` from it, and back up the rollout reward.

# Arguments
- `a::T1` one of Yiem's agents
- `initialState` initial state (an `AbstractDict`)
- `decisionMaker::Function` decides what action to take
- `progressValueEstimator::Function` assesses the value of a state
- `reflector::Function` generates a lesson from a trajectory and reward (not used yet)
- `isterminal::Function` determines whether a state is terminal and its reward
- `n::Integer` number of decisions sampled per expansion (≥ 1)
- `maxDepth::Integer` maximum rollout depth in `simulate`
- `maxIterations::Integer` number of search iterations (≥ 1)
- `w::Real` UCT exploration weight

# Return
- the `state` of the root child with the highest `progressValue`. For visited children
  this is the mean backed-up reward.

# Throws
- `ArgumentError` if `n < 1` or `maxIterations < 1`

# TODO
- [] use `reflector` to generate lessons from finished trajectories
"""
function runMCTS(
    a::T1,
    initialState,
    decisionMaker::Function,
    progressValueEstimator::Function,
    reflector::Function,
    isterminal::Function,
    n::Integer,
    maxDepth::Integer,
    maxIterations::Integer,
    w::Real) where {T1<:agent}

    n >= 1 || throw(ArgumentError("n must be ≥ 1, got $n"))
    maxIterations >= 1 || throw(ArgumentError("maxIterations must be ≥ 1, got $maxIterations"))

    root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())

    for _ in 1:maxIterations
        # 1. selection
        node = root
        while !isleaf(node)
            node = UCTselect(node, w)
        end

        # A terminal leaf cannot be expanded; just back up its reward.
        if node.isterminal
            backpropagate(node, node.reward)
            continue
        end

        # 2. expansion
        expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)

        # 3. simulation — per the LATS paper it starts at the selected node itself,
        #    not at one of the children just created.
        reward = simulate(a, node, decisionMaker, progressValueEstimator, isterminal,
                          maxDepth, n=n)

        # 4. backpropagation from the selected node up to the root
        backpropagate(node, reward)
    end

    # The root always has children here: iteration 1 expands it with n ≥ 1 samples.
    bestChild = selectChildNode(root)
    return bestChild.state
end

end # module mcts