""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence Bound for Trees) selection function, you can follow the steps below: Define the necessary types and functions for the MCTS algorithm: """ module mcts export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition, userChatbox using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting using GeneralUtils using ..type, ..llmfunction # ---------------------------------------------- 100 --------------------------------------------- # """ a node for MCTS search tree # Arguments - `state::T` a state of a game. Can be a Dict or something else. - `visits::Integer ` number of time the game visits this state - `stateValue::Float64` state value - `children::Dict{T, MCTSNode}` children node # Return - `nothing` # Example ```jldoctest julia> state = Dict( :info=> Dict(), # keyword info :thoughtHistory=> Dict( :question=> _, :thought_1=> _, :action_1=> _, :observation_1=> _, :thought_2=> _, ... ) ) ``` # TODO [] update docstring # Signature """ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} nodekey::T2 state::T1 visits::Integer progressvalue::Number # estimate value by LLM's reasoning statevalue::Number # store discounted commulative reward (gather from its child node) reward::Number # this node's own reward isterminal::Bool parent::Union{MCTSNode, Nothing} children::Dict{String, MCTSNode} end """ Select a node based on UCT score # Arguments - `node::MCTSNode` mcts node - `w::T` exploration weight. Value is usually between 1 to 2. Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%. Value 2.0 makes MCTS aggressively search the tree. # Return - `selectedNode::MCTSNode` # Example ```jldoctest julia> ``` # Signature """ function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} maxUCT = -Inf selectedNode = nothing for (childState, childNode) in node.children UCTvalue = if childNode.visits != 0 weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term childNode.statevalue + weightedterm else # node.visits == 0 makes sqrt() in explore term error childNode.progressvalue # exploit term end if UCTvalue > maxUCT maxUCT = UCTvalue selectedNode = childNode end end return selectedNode end """ Expand selected node # Arguments - `a::T1` One of YiemAgent's agent - `node::MCTSNode` MCTS node - `state::T2` a state of a game. Can be a Dict or something else. - `decisionMaker::Function` a function that output Thought and Action - `evaluator::Function` a function that output trajectory progress score # Return # Example ```jldoctest julia> ``` # TODO [] update docstring [] try loop should limit to 3 times. if not succeed, skip [] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise. [x] store feedback -> state -> agent. # Signature """ function expand(a::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function, reflector::Function; totalsample::Integer=3 ) where {T1<:agent} nthSample = 0 while true nthSample += 1 if nthSample <= totalsample thoughtDict = decisionMaker(a, node.state) println("---> expand() sample $nthSample") pprintln(node.state[:thoughtHistory]) pprintln(thoughtDict) node.state[:thoughtDict] = thoughtDict newNodeKey, newstate = MCTStransition(a, node.state) # add evaluator stateevaluation, progressvalue = evaluator(a, newstate) if newstate[:reward] < 0 pprint(newstate[:thoughtHistory]) newstate[:evaluation] = stateevaluation newstate[:lesson] = reflector(a, newstate) # store new lesson for later use lessonDict = copy(JSON3.read("lesson.json")) latestLessonKey, latestLessonIndice = GeneralUtils.findHighestIndexKey(lessonDict, "lesson") nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1 newLessonKey = Symbol("lesson_$(nextIndice)") lessonDict[newLessonKey] = newstate open("lesson.json", "w") do io JSON3.pretty(io, lessonDict) end print("---> reflector()") end if newNodeKey ∉ keys(node.children) node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], newstate[:isterminal], node, Dict{String, MCTSNode}()) end else break end end end """ Simulate interactions between agent and environment # Arguments - `a::T` one of YiemAgent's agent - `node::MCTSNode` node that will be a simulation starting point. - `decisionMaker::Function` function that receive state return Thought and Action # Return - `simTrajectoryReward::Number` # Example ```jldoctest julia> ``` # TODO - [] update docs # Signature """ function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, reflector::Function; maxDepth::Integer=3, totalsample::Integer=3 )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:agent} simTrajectoryReward = 0.0 terminalstate = nothing for depth in 1:maxDepth simTrajectoryReward += node.reward if node.isterminal terminalstate = node.state break else expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample) node = selectChildNode(node) end end return (simTrajectoryReward, terminalstate) end """ Backpropagate reward along the simulation chain # Arguments - `node::MCTSNode` leaf node of a search tree - `simTrajectoryReward::T` total reward from trajectory simulation # Return - `No return` # Example ```jldoctest julia> ``` # Signature """ function backpropagate(node::MCTSNode, simTrajectoryReward::T; discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} while !isroot(node) # Update the statistics of the current node based on the result of the playout node.visits += 1 node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain node = node.parent end end """ Get a new state # Arguments - `a::T1` one of YiemAgent's agent - `state::T2` current game state - `thoughtDict::T3` contain Thought, Action, Observation - `isterminal::Function` a function to determine terminal state # Return - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}` # Example ```jldoctest julia> state = Dict{Symbol, Dict{Symbol, Any}}( :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."), :storeinfo => Dict(), :customerinfo => Dict() ) julia> thoughtDict = Dict( :question=> "I want to buy a bottle of wine.", :thought_1=> "The customer wants to buy a bottle of wine.", :action_1=> Dict{Symbol, Any}( :name=>"Chatbox", :input=>"What occasion are you buying the wine for?", ), :observation_1 => "" ) ``` # TODO - [x] add other actions - [] add embedding of newstate and store in newstate[:embedding] # Signature """ function MCTStransition(a::T1, state::T2 )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict} thoughtDict = state[:thoughtDict] actionname = thoughtDict[:action][:name] actioninput = thoughtDict[:action][:input] # map action and input() to llm function response, select, reward, isterminal = if actionname == "chatbox" virtualWineUserChatbox(a, actioninput) # virtual customer elseif actionname == "winestock" winestock(a, actioninput) elseif actionname == "recommendbox" virtualWineUserRecommendbox(a, actioninput) else error("undefined LLM function. Requesting $actionname") end latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought") nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 latestThoughtKey = Symbol("thought_$nextIndice") latestActionKey = Symbol("action_$nextIndice") # add Thought, action, observation to thoughtHistory newstate = deepcopy(state) newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] newObservationKey = Symbol("observation_$(nextIndice)") newstate[:thoughtHistory][newObservationKey] = response newstate[:reward] = reward newstate[:select] = select newstate[:isterminal] = isterminal newNodeKey = GeneralUtils.uuid4snakecase() return (newNodeKey, newstate) end """ Get a new state # Arguments - `a::T1` one of YiemAgent's agent - `state::T2` current game state - `thoughtDict::T3` contain Thought, Action, Observation - `isterminal::Function` a function to determine terminal state # Return - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}` # Example ```jldoctest julia> state = Dict{Symbol, Dict{Symbol, Any}}( :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."), :storeinfo => Dict(), :customerinfo => Dict() ) julia> thoughtDict = Dict( :question=> "I want to buy a bottle of wine.", :thought_1=> "The customer wants to buy a bottle of wine.", :action_1=> Dict{Symbol, Any}( :name=>"Chatbox", :input=>"What occasion are you buying the wine for?", ), :observation_1 => "" ) ``` # TODO - [x] add other actions - [] add embedding of newstate and store in newstate[:embedding] # Signature """ function transition(a::T1, state::T2 )::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict} thoughtDict = state[:thoughtDict] actionname = thoughtDict[:action][:name] actioninput = thoughtDict[:action][:input] # map action and input() to llm function response, select, reward, isterminal = if actionname == "chatbox" userChatbox(a, actioninput) # virtual customer elseif actionname == "winestock" winestock(a, actioninput) elseif actionname == "recommendbox" userRecommendbox(a, actioninput) else error("undefined LLM function. Requesting $actionname") end latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought") nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 latestThoughtKey = Symbol("thought_$nextIndice") latestActionKey = Symbol("action_$nextIndice") # add Thought, action, observation to thoughtHistory newstate = deepcopy(state) newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] newObservationKey = Symbol("observation_$(nextIndice)") newstate[:thoughtHistory][newObservationKey] = response newstate[:reward] = reward newstate[:select] = select newstate[:isterminal] = isterminal return newstate end """ Determine whether a node is a leaf node of a search tree. # Arguments - `node::MCTSNode` a search tree node # Return - `result::Bool` true if it is a leaf node, false otherwise. # Example ```jldoctest julia> using Revise julia> using YiemAgent, DataStructures julia> initialState = Dict{Symbol, Any}( :customerinfo=> Dict{Symbol, Any}(), :storeinfo=> Dict{Symbol, Any}(), :thoughtHistory=> OrderedDict{Symbol, Any}( :question=> "How are you?", ) ) julia> statetype = typeof(initialState) julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}()) julia> YiemAgent.isleaf(root) true ``` # TODO [] update docs # Signature """ isleaf(node::MCTSNode)::Bool = isempty(node.children) """ Select child node based on the highest statevalue # Arguments - `node::MCTSNode` node of a search tree # Return - `childNode::MCTSNode` the highest value child node # Example ```jldoctest julia> ``` # Signature """ function selectChildNode(node::MCTSNode)::MCTSNode highestProgressValue = 0 nodekey = nothing # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children potential = childNode.progressvalue + childNode.reward if childNode.reward > 0 #XXX for testing. remove when done. println("") end if potential > highestProgressValue highestProgressValue = potential nodekey = childNode.nodekey end end return node.children[nodekey] end """ # Arguments - `node::MCTSNode` node of a search tree # Return - `childNode::MCTSNode` the highest value child node # Example ```jldoctest julia> ``` # TODO - [] update docs - [TESTING] implement the function # Signature """ function selectBestNextState(node::MCTSNode)::MCTSNode highestProgressValue = 0 nodekey = nothing # if all childnode has statevalue == 0, use progressvalue + reward to select the best node stateValueSum = sum([v.statevalue for (k, v) in node.children]) if stateValueSum != 0 for (k, childnode) in node.children potential = childnode.statevalue / childnode.visits if potential > highestProgressValue highestProgressValue = potential nodekey = childnode.nodekey end end else for (k, childnode) in node.children potential = childnode.progressvalue + childnode.reward if potential > highestProgressValue highestProgressValue = potential nodekey = childnode.nodekey end end end return node.children[nodekey] end """ # Arguments - `node::MCTSNode` node of a search tree # Return - `childNode::MCTSNode` the highest value child node # Example ```jldoctest julia> ``` # TODO - [] update docs - [TESTING] implement the function # Signature """ function selectBestTrajectory(node::MCTSNode)::MCTSNode while !isleaf(node) node = selectBestNextState(node) end return node end """ Determine wheter a given node is a root node # Arguments - `node::MCTSNode` node of a search tree # Return - `isrootnode::Bool` true if the given node is root node, false otherwise # Example ```jldoctest julia> ``` # Signature """ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false # ------------------------------------------------------------------------------------------------ # # Create a complete example using the defined MCTS functions # # ------------------------------------------------------------------------------------------------ # """ Search the best action to take for a given state and task # Arguments - `a::agent` one of Yiem's agents - `initial state` initial state - `decisionMaker::Function` decide what action to take - `evaluator::Function` assess the value of the state - `reflector::Function` generate lesson from trajectory and reward - `isterminal::Function` determine whether a given state is a terminal state - `n::Integer` how many times action will be sampled from decisionMaker - `w::Float64` exploration weight. Value is usually between 1 to 2. Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50% Value 2.0 makes MCTS aggressively search the tree # Return - `plan::Vector{Dict}` best plan # Example ```jldoctest julia> ``` # TODO [] update docstring [x] return best action # Signature """ function runMCTS( a::T1, initialState, decisionMaker::Function, evaluator::Function, reflector::Function; totalsample::Integer=3, maxDepth::Integer=3, maxiterations::Integer=10, explorationweight::Number=1.0, ) where {T1<:agent} root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) for nth in 1:maxiterations node = root node.visits += 1 while !isleaf(node) node = UCTselect(node, explorationweight) end if node.isterminal # MCTS arrive at the leaf node that is also a terminal state, # do nothing then go directly to backpropagation backpropagate(leafNode, node.reward) else expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample) leafNode = selectChildNode(node) simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker, evaluator, reflector; maxDepth=maxDepth, totalsample=totalsample) if terminalstate !== nothing terminalstate[:totalTrajectoryReward] = simTrajectoryReward end #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation # open("trajectory.json", "w") do io # JSON3.pretty(io, terminalstate) # end backpropagate(leafNode, simTrajectoryReward) end end bestNextState = selectBestNextState(root) besttrajectory = selectBestTrajectory(root) return (bestNextState.state, besttrajectory.state) end end # module mcts