# Provenance: recovered from a `git format-patch` e-mail that deleted `src/mcts.jl`
# (commit abefbaa743b43e23fe27118650845081735a01da, author "narawat lamaiin",
# Sat, 1 Jun 2024 08:20:39 +0700, subject "[PATCH] update", 729 deletions).

"""
Monte Carlo Tree Search (MCTS) over Thought/Action/Observation states produced by an
LLM-driven agent.

Background reading: https://www.harrycodes.com/blog/monte-carlo-tree-search
"""
module mcts

# NOTE(review): `userChatbox` is exported but not defined anywhere in this module as seen
# here; the export is kept so existing `using` statements keep working.
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
    userChatbox, makeNewState

using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils
using ..type, ..llmfunction

# ---------------------------------------------- 100 --------------------------------------------- #

"""
    MCTSNode{T1<:AbstractDict, T2<:AbstractString}

A node of the MCTS search tree.

# Fields
    - `nodekey::T2`
        key under which the node is stored in its parent's `children`; the root built by
        `runMCTS` uses the key "root"
    - `state::T1`
        game state held by this node, e.g.
        `Dict(:thoughtHistory => OrderedDict(:question => ..., :thought_1 => ...,
        :action_1 => ..., :observation_1 => ...), :reward => ..., :isterminal => ...)`
    - `visits::Integer`
        number of times backpropagation passed through this node
    - `progressvalue::Number`
        value estimated by the evaluator (LLM reasoning) when the node was created
    - `statevalue::Number`
        running mean of the discounted simulated returns backpropagated through the node
    - `reward::Number`
        this node's own reward
    - `isterminal::Bool`
        whether the state ends the episode
    - `parent::Union{MCTSNode, Nothing}`
        parent node, `nothing` for the root
    - `children::Dict{String, MCTSNode}`
        child nodes keyed by their `nodekey`

# TODO
    [] fields typed `Integer`/`Number` are abstract; parametrize them if this becomes hot
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
    nodekey::T2
    state::T1
    visits::Integer
    progressvalue::Number       # estimate value by LLM's reasoning
    statevalue::Number          # running mean of discounted returns (see backpropagate)
    reward::Number              # this node's own reward
    isterminal::Bool
    parent::Union{MCTSNode, Nothing}
    children::Dict{String, MCTSNode}
end


# Return the eligible child of `node` with the highest `score(child)`, or `nothing` when
# there is no eligible child with a comparable (non-NaN) score. Ties keep the child met
# first while iterating `node.children`.
function _bestchild(score::Function, node::MCTSNode; eligible::Function = _ -> true)
    best = nothing
    bestscore = -Inf
    for child in values(node.children)
        eligible(child) || continue
        s = score(child)
        # the second clause lets a first child scoring exactly -Inf still be picked;
        # NaN never compares greater, so NaN scores are skipped
        if s > bestscore || (best === nothing && !isnan(s))
            best = child
            bestscore = s
        end
    end
    return best
end


"""
    UCTselect(node::MCTSNode, w::Real) -> MCTSNode

Select the child of `node` with the highest UCT score.

A visited child scores `statevalue + w * sqrt(log(node.visits) / child.visits)`.
A child that was never visited scores its `progressvalue`, because the exploration term
is undefined for zero visits.

# Arguments
    - `node::MCTSNode`
        node whose children are compared
    - `w::Real`
        exploration weight. Value is usually between 1 to 2.
        Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
        Value 2.0 makes MCTS aggressively search the tree.

# Return
    - `selectedNode::MCTSNode`

# Throws
    - `ArgumentError` when `node` has no selectable child
"""
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:Real}
    selected = _bestchild(node) do child
        if child.visits != 0
            # max(…, 1) keeps log() non-negative so sqrt() cannot throw a DomainError
            exploreterm = w * sqrt(log(max(node.visits, 1)) / child.visits)
            child.statevalue + exploreterm
        else
            child.progressvalue
        end
    end
    selected === nothing &&
        throw(ArgumentError("UCTselect: node \"$(node.nodekey)\" has no selectable child"))

    return selected
end


# Append `state` to the JSON lesson store at `lessonfile` under the next free
# `lesson_N` key. A missing or empty file is treated as an empty store.
# NOTE(review): relies on `GeneralUtils.findHighestIndexKey` returning `:NA` as the key
# when no `lesson_N` entry exists, as the original code did — confirm against GeneralUtils.
function _storelesson(lessonfile::AbstractString, state::AbstractDict)
    lessons =
        if isfile(lessonfile) && filesize(lessonfile) > 0
            copy(JSON3.read(read(lessonfile, String)))
        else
            Dict{Symbol, Any}()
        end
    latestKey, latestIndex = GeneralUtils.findHighestIndexKey(lessons, "lesson")
    nextIndex = latestKey == :NA ? 1 : latestIndex + 1
    lessons[Symbol("lesson_$(nextIndex)")] = state
    open(lessonfile, "w") do io
        JSON3.pretty(io, lessons)
    end
    return nothing
end


"""
    expand(a, node, decisionMaker, evaluator, reflector; totalsample=3,
           lessonfile="lesson.json")

Expand `node` by sampling `totalsample` actions and attaching one child per new state.

For every sample `decisionMaker(a, node.state)` proposes a thought/action, the action is
executed with `MCTStransition`, and `evaluator(a, newstate)` scores the resulting state.
When the new state's reward is negative, `reflector(a, newstate)` produces a lesson that
is stored in the state and appended to `lessonfile`.

# Arguments
    - `a::T1`
        one of YiemAgent's agent
    - `node::MCTSNode`
        node to expand; its `children` dictionary is mutated
    - `decisionMaker::Function`
        a function that output Thought and Action
    - `evaluator::Function`
        a function that returns `(stateevaluation, progressvalue)`
    - `reflector::Function`
        a function that turns a failed state into a lesson

# Keywords
    - `totalsample::Integer`
        how many actions are sampled
    - `lessonfile::AbstractString`
        path of the JSON file collecting lessons

# Return
    - `nothing`

# TODO
    [] try loop should limit to 3 times. if not succeed, skip
    [] newNodeKey ∉ keys(node.children). New state may have semantic vector close enough
       to one of existing child state, which can be assumed to be the same state
       semantically.
    [x] store feedback -> state -> agent.
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
    evaluator::Function, reflector::Function; totalsample::Integer=3,
    lessonfile::AbstractString="lesson.json"
    ) where {T1<:agent}

    for nthSample in 1:totalsample
        thoughtDict = decisionMaker(a, node.state)
        println("---> expand() sample $nthSample")
        pprintln(node.state[:thoughtHistory])
        pprintln(thoughtDict)
        newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict)

        stateevaluation, progressvalue = evaluator(a, newstate)

        if newstate[:reward] < 0
            pprint(newstate[:thoughtHistory])
            newstate[:evaluation] = stateevaluation
            newstate[:lesson] = reflector(a, newstate)

            # store new lesson for later use
            _storelesson(lessonfile, newstate)
            print("---> reflector()")
        end

        if !haskey(node.children, newNodeKey)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                    newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end

    return nothing
end


"""
    simulate(a, node, decisionMaker, evaluator, reflector; maxDepth=3, totalsample=3,
             lessonfile="lesson.json") -> (simTrajectoryReward, terminalstate)

Simulate interactions between agent and environment starting from `node`.

For at most `maxDepth` steps the current node's reward is added to the trajectory reward;
a terminal node stops the rollout, otherwise the node is expanded and the rollout moves
to the child chosen by `selectChildNode`.

# Arguments
    - `a::T`
        one of YiemAgent's agent
    - `node::MCTSNode`
        node that will be a simulation starting point
    - `decisionMaker::Function`
        function that receive state return Thought and Action
    - `evaluator::Function`, `reflector::Function`
        forwarded to `expand`

# Return
    - `simTrajectoryReward::Number`
        sum of the rewards of the nodes visited by the rollout
    - `terminalstate`
        state of the terminal node that ended the rollout, or `nothing` when `maxDepth`
        was exhausted first
"""
function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
    reflector::Function; maxDepth::Integer=3, totalsample::Integer=3,
    lessonfile::AbstractString="lesson.json"
    )::Tuple{Number, Union{AbstractDict, Nothing}} where {T<:agent}

    simTrajectoryReward = 0.0
    terminalstate = nothing

    for _ in 1:maxDepth
        simTrajectoryReward += node.reward
        if node.isterminal
            terminalstate = node.state
            break
        end
        expand(a, node, decisionMaker, evaluator, reflector;
            totalsample=totalsample, lessonfile=lessonfile)
        node = selectChildNode(node)
    end

    return (simTrajectoryReward, terminalstate)
end


"""
    backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)

Backpropagate reward along the chain from `node` up to (but excluding) the root.

Each node on the way gets `visits` incremented and `statevalue` updated to the running
mean of the rewards seen so far. The reward is multiplied by `discountRewardCoeff` at
every step up, because reward observed further from a node is less certain for it.

# Arguments
    - `node::MCTSNode`
        leaf node of a search tree
    - `simTrajectoryReward::T`
        total reward from trajectory simulation

# Return
    - `nothing`
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
    discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}

    reward = simTrajectoryReward
    current = node
    while !isroot(current)
        current.visits += 1
        # incremental mean: new = old + (x - old) / n  ==  (old * (n - 1) + x) / n
        current.statevalue += (reward - current.statevalue) / current.visits
        reward *= discountRewardCoeff
        current = current.parent
    end

    return nothing
end


"""
    MCTStransition(a, state, thoughtDict) -> (newNodeKey, newstate)

Execute the action in `thoughtDict` against the simulated environment and build the
resulting state.

`thoughtDict[:action][:name]` selects the LLM function: "chatbox" (virtual customer),
"winestock" or "recommendbox". For "chatbox" the assistant's input and the virtual
customer's response are appended to `newstate[:virtualCustomerChatHistory]`.

# Arguments
    - `a::T1`
        one of YiemAgent's agent
    - `state::T2`
        current game state; it is not mutated
    - `thoughtDict::T3`
        contain Thought and Action, e.g.
        `Dict(:thought => "...", :action => Dict(:name => "chatbox", :input => "..."))`

# Return
    - `(newNodeKey, newstate)`

# Throws
    - `ErrorException` when the action name is not one of the supported functions

# TODO
    - [x] add other actions
    - [] add embedding of newstate and store in newstate[:embedding]
"""
function MCTStransition(a::T1, state::T2, thoughtDict::T3
    )::Tuple{String, AbstractDict} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}

    actionname = thoughtDict[:action][:name]
    actioninput = thoughtDict[:action][:input]

    # map action and input() to llm function
    response, select, reward, isterminal =
        if actionname == "chatbox"
            # deepcopy the chat history so other simulations starting from this same node
            # are not contaminated with actioninput
            virtualWineUserChatbox(a, actioninput,
                deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
        elseif actionname == "winestock"
            winestock(a, actioninput)
        elseif actionname == "recommendbox"
            virtualWineUserRecommendbox(a, actioninput)
        else
            error("undefined LLM function. Requesting $actionname")
        end

    newNodeKey, newstate =
        makeNewState(state, thoughtDict, response, select, reward, isterminal)
    if actionname == "chatbox"
        chathistory = newstate[:virtualCustomerChatHistory]
        push!(chathistory, Dict(:name=>"assistant", :text=> actioninput))
        push!(chathistory, Dict(:name=>"user", :text=> response))
    end

    return (newNodeKey, newstate)
end


"""
    transition(a, state, thoughtDict) -> newstate

Execute the action in `thoughtDict` against the real environment and return the
resulting state. Only the "winestock" action is supported here.

# Arguments
    - `a::T1`
        one of YiemAgent's agent
    - `state::T2`
        current game state; it is not mutated
    - `thoughtDict::T3`
        contain Thought and Action, e.g.
        `Dict(:thought => "...", :action => Dict(:name => "winestock", :input => "..."))`

# Return
    - `newstate`
        copy of `state` extended with the new thought, action and observation

# Throws
    - `ErrorException` when the action name is not "winestock"

# TODO
    - [] add other actions
    - [] add embedding of newstate and store in newstate[:embedding]
"""
function transition(a::T1, state::T2, thoughtDict::T3
    )::AbstractDict where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}

    actionname = thoughtDict[:action][:name]
    actioninput = thoughtDict[:action][:input]

    # map action and input() to llm function
    response, select, reward, isterminal =
        if actionname == "winestock"
            winestock(a, actioninput)
        else
            error("undefined LLM function. Requesting $actionname")
        end

    # makeNewState also returns a node key, which only the search tree needs
    _, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal)

    return newstate
end


"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)
        -> (newNodeKey, newstate)

Build the state that follows `currentstate` after taking the action in `thoughtDict`.

`currentstate` is deep-copied; the copy's `:thoughtHistory` receives the next
`thought_N`, `action_N` and `observation_N` entries (the observation is `response`), and
`:reward`, `:select` and `:isterminal` are set from the arguments.

# Arguments
    - `currentstate::AbstractDict`
        state to extend; must contain `:thoughtHistory`
    - `thoughtDict::AbstractDict`
        holds the thought and action, either under `:thought`/`:action` or under indexed
        keys `:thought_N`/`:action_N`
    - `response::AbstractString`
        observation returned by the executed action
    - `select::Union{Number, Nothing}`
        stored as-is under `:select`
    - `reward::Number`
        reward of the new state
    - `isterminal::Bool`
        whether the new state is terminal

# Return
    - `newNodeKey::String`
        fresh key from `GeneralUtils.uuid4snakecase()`
    - `newstate`
        the extended copy of `currentstate`

# TODO
    - [x] implement the function
"""
function makeNewState(currentstate::AbstractDict, thoughtDict::AbstractDict,
    response::AbstractString, select::Union{Number, Nothing}, reward::Number,
    isterminal::Bool
    )::Tuple{String, AbstractDict}

    # index of the next thought/action/observation triple in the history
    latestKey, latestIndex =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    nextIndex = latestKey == :NA ? 1 : latestIndex + 1

    # thoughtDict carries either plain keys or indexed keys
    _, sourceIndex = GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
    sourceThoughtKey, sourceActionKey =
        if sourceIndex == -1
            (:thought, :action)
        else
            (Symbol("thought_$sourceIndex"), Symbol("action_$sourceIndex"))
        end

    # add Thought, action, observation to thoughtHistory
    newstate = deepcopy(currentstate)
    history = newstate[:thoughtHistory]
    history[Symbol("thought_$nextIndex")] = thoughtDict[sourceThoughtKey]
    history[Symbol("action_$nextIndex")] = thoughtDict[sourceActionKey]
    history[Symbol("observation_$(nextIndex)")] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    newNodeKey = GeneralUtils.uuid4snakecase()

    return (newNodeKey, newstate)
end


"""
    isleaf(node::MCTSNode) -> Bool

Determine whether a node is a leaf node of a search tree, i.e. has no children.

# Example
```julia
julia> using DataStructures
julia> initialState = Dict{Symbol, Any}(
            :customerinfo=> Dict{Symbol, Any}(),
            :storeinfo=> Dict{Symbol, Any}(),
            :thoughtHistory=> OrderedDict{Symbol, Any}(
                :question=> "How are you?",
            )
        )
julia> root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing,
            Dict{String, MCTSNode}())
julia> isleaf(root)
true
```
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)


"""
    selectChildNode(node::MCTSNode) -> MCTSNode

Select the child of `node` with the highest `progressvalue + reward`.

# Arguments
    - `node::MCTSNode`
        node of a search tree

# Return
    - `childNode::MCTSNode`
        the highest value child node

# Throws
    - `ArgumentError` when `node` has no selectable child
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    best = _bestchild(child -> child.progressvalue + child.reward, node)
    best === nothing && throw(ArgumentError(
        "selectChildNode: node \"$(node.nodekey)\" has no selectable child"))

    return best
end


"""
    selectBestNextState(node::MCTSNode) -> MCTSNode

Select the most promising child of `node` after the search has run.

When at least one child has a non-zero `statevalue`, the visited child with the highest
`statevalue` (mean backpropagated return) wins. When all children still have
`statevalue == 0`, the choice falls back to the highest `progressvalue + reward`.

# Arguments
    - `node::MCTSNode`
        node of a search tree

# Return
    - `childNode::MCTSNode`
        the highest value child node

# Throws
    - `ArgumentError` when `node` has no selectable child
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
    hasbackedupvalue = any(child -> child.statevalue != 0, values(node.children))

    best =
        if hasbackedupvalue
            _bestchild(child -> child.statevalue, node;
                eligible = child -> child.visits > 0)
        else
            _bestchild(child -> child.progressvalue + child.reward, node)
        end
    best === nothing && throw(ArgumentError(
        "selectBestNextState: node \"$(node.nodekey)\" has no selectable child"))

    return best
end


"""
    selectBestTrajectory(node::MCTSNode) -> MCTSNode

Follow `selectBestNextState` from `node` down to a leaf and return that leaf.
A `node` that is already a leaf is returned unchanged.

# Arguments
    - `node::MCTSNode`
        node of a search tree

# Return
    - `leafNode::MCTSNode`
        last node of the best trajectory
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
    while !isleaf(node)
        node = selectBestNextState(node)
    end

    return node
end


"""
    isroot(node::MCTSNode) -> Bool

Determine whether a given node is a root node: it has no parent or its key is "root".

# Arguments
    - `node::MCTSNode`
        node of a search tree

# Return
    - `isrootnode::Bool`
        true if the given node is root node, false otherwise
"""
isroot(node::MCTSNode)::Bool = node.parent === nothing || node.nodekey == "root"

# ------------------------------------------------------------------------------------------------ #
#                   Create a complete example using the defined MCTS functions                     #
# ------------------------------------------------------------------------------------------------ #
"""
    runMCTS(a, initialState, decisionMaker, evaluator, reflector; totalsample=3,
            maxDepth=3, maxiterations=10, explorationweight=1.0,
            lessonfile="lesson.json") -> (bestNextState, bestTrajectoryState)

Search the best action to take for a given state and task.

Each iteration descends from the root with `UCTselect` until a leaf is reached. A
terminal leaf backpropagates its own reward. Any other leaf is expanded, one child is
picked with `selectChildNode`, a rollout is simulated from it, and the rollout reward is
backpropagated.

# Arguments
    - `a::agent`
        one of Yiem's agents
    - `initialState`
        initial state (an `AbstractDict`)
    - `decisionMaker::Function`
        decide what action to take
    - `evaluator::Function`
        assess the value of the state
    - `reflector::Function`
        generate lesson from trajectory and reward

# Keywords
    - `totalsample::Integer`
        how many times action will be sampled from decisionMaker
    - `maxDepth::Integer`
        maximum number of steps of a simulated rollout
    - `maxiterations::Integer`
        number of MCTS iterations
    - `explorationweight::Number`
        exploration weight. Value is usually between 1 to 2.
        Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
        Value 2.0 makes MCTS aggressively search the tree
    - `lessonfile::AbstractString`
        path of the JSON file collecting lessons

# Return
    - `(bestNextState, bestTrajectoryState)`
        state of the best child of the root, and state of the leaf at the end of the
        best trajectory

# Throws
    - `ArgumentError` when the root was never expanded (e.g. `maxiterations < 1`)

# TODO
    [] write best state to file if it has higher simTrajectoryReward. Use to improve
       evaluation
    [x] return best action
"""
function runMCTS(
    a::T1,
    initialState,
    decisionMaker::Function,
    evaluator::Function,
    reflector::Function;
    totalsample::Integer=3,
    maxDepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
    lessonfile::AbstractString="lesson.json",
    ) where {T1<:agent}

    root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing,
        Dict{String, MCTSNode}())

    for _ in 1:maxiterations
        node = root
        node.visits += 1

        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            # MCTS arrived at a leaf that is also a terminal state: nothing to expand,
            # go directly to backpropagation from that node
            backpropagate(node, node.reward)
        else
            expand(a, node, decisionMaker, evaluator, reflector;
                totalsample=totalsample, lessonfile=lessonfile)
            leafNode = selectChildNode(node)
            simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker,
                evaluator, reflector; maxDepth=maxDepth, totalsample=totalsample,
                lessonfile=lessonfile)
            # the rollout may run out of depth before reaching a terminal state
            if terminalstate !== nothing
                terminalstate[:totalTrajectoryReward] = simTrajectoryReward
            end

            backpropagate(leafNode, simTrajectoryReward)
        end
    end

    bestNextState = selectBestNextState(root)
    besttrajectory = selectBestTrajectory(root)

    return (bestNextState.state, besttrajectory.state)
end

end # module mcts