From 452262d3d6cb7942d60599a43e675ed5f89de268 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Fri, 31 May 2024 11:47:51 +0700 Subject: [PATCH] update --- .vscode/settings.json | 1 + src/LLMMCTS.jl | 25 ++- src/interface.jl | 179 +++++++++++++++++ src/mcts.jl | 432 ++++++++++++++++++++++++++++++++++++++++++ src/type.jl | 116 ++++++++++++ src/util.jl | 139 ++++++++++++++ 6 files changed, 891 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json create mode 100644 src/interface.jl create mode 100644 src/mcts.jl create mode 100644 src/type.jl create mode 100644 src/util.jl diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/src/LLMMCTS.jl b/src/LLMMCTS.jl index 38c63dc..f5f6e76 100644 --- a/src/LLMMCTS.jl +++ b/src/LLMMCTS.jl @@ -1,5 +1,28 @@ module LLMMCTS -greet() = print("Hello World!") + # export agent + + + """ Order by dependencies of each file. The 1st included file must not depend on any other + files and each file can only depend on the file included before it. + """ + + include("type.jl") + using .type + + include("util.jl") + using .util + + include("mcts.jl") + using .mcts + + include("interface.jl") + using .interface + + +# ---------------------------------------------- 100 --------------------------------------------- # + + + end # module LLMMCTS diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..b508dc1 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,179 @@ +module interface + +export runMCTS + +using ..type, ..mcts + + +# ---------------------------------------------- 100 --------------------------------------------- # + + + +""" Search the best action to take for a given state and task + +# Arguments + - `a::agent` + one of Yiem's agents + - `initial state` + initial state + - `decisionMaker::Function` + decide what action to take + - `evaluator::Function` + assess the value of the state + - `reflector::Function` + generate lesson from trajectory and reward + - `isterminal::Function` + determine whether a given state is a terminal state + - `n::Integer` + how many times action will be sampled from decisionMaker + - `w::Float64` + exploration weight. Value is usually between 1 to 2. + Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50% + Value 2.0 makes MCTS aggressively search the tree + +# Return + - `plan::Vector{Dict}` + best plan + +# Example +```jldoctest +julia> +``` + +# TODO + [] update docstring + [] return best action + +# Signature +""" +function runMCTS( + workDict::Dict{Symbol, Any}, + initialState, + decisionMaker::Function, + evaluator::Function, + reflector::Function, + transition::Function, + ; + totalsample::Integer=3, + maxDepth::Integer=3, + maxiterations::Integer=10, + explorationweight::Number=1.0, + ) + + root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) + + for nth in 1:maxiterations + node = root + node.visits += 1 + + while !isleaf(node) + node = UCTselect(node, explorationweight) + end + if node.isterminal + # MCTS arrive at the leaf node that is also a terminal state, + # do nothing then go directly to backpropagation + backpropagate(leafNode, node.reward) + else + expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample) + leafNode = selectChildNode(node) + simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator, + reflector; maxDepth=maxDepth, totalsample=totalsample) + if terminalstate !== nothing #XXX not sure why I need this + terminalstate[:totalTrajectoryReward] = simTrajectoryReward + end + + #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation + # open("trajectory.json", "w") do io + # JSON3.pretty(io, terminalstate) + # end + + backpropagate(leafNode, simTrajectoryReward) + end + end + + bestNextState = selectBestNextState(root) + besttrajectory = selectBestTrajectory(root) + + return (bestNextState.state, besttrajectory.state) +end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +end # module interface \ No newline at end of file diff --git a/src/mcts.jl b/src/mcts.jl new file mode 100644 index 0000000..a5bbe97 --- /dev/null +++ b/src/mcts.jl @@ -0,0 +1,432 @@ +module mcts + +export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode, + expand, mctstransition + +using ..type + +# ---------------------------------------------- 100 --------------------------------------------- # + + +""" + +# Arguments + - `node::MCTSNode` + node of a search tree + +# Return + - `childNode::MCTSNode` + the highest value child node + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docs + - [x] implement the function + +# Signature +""" +function selectBestNextState(node::MCTSNode)::MCTSNode + highestProgressValue = 0 + nodekey = nothing + + # if all childnode has statevalue == 0, use progressvalue + reward to select the best node + stateValueSum = sum([v.statevalue for (k, v) in node.children]) + + if stateValueSum != 0 + for (k, childnode) in node.children + potential = childnode.statevalue / childnode.visits + + if potential > highestProgressValue + highestProgressValue = potential + nodekey = childnode.nodekey + end + end + else + for (k, childnode) in node.children + potential = childnode.progressvalue + childnode.reward + + if potential > highestProgressValue + highestProgressValue = potential + nodekey = childnode.nodekey + end + end + end + + return node.children[nodekey] +end + + +""" + +# Arguments + - `node::MCTSNode` + node of a search tree + +# Return + - `childNode::MCTSNode` + the highest value child node + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docs + - [x] implement the function + +# Signature +""" +function selectBestTrajectory(node::MCTSNode)::MCTSNode + while !isleaf(node) + node = selectBestNextState(node) + end + + return node +end + + +""" Backpropagate reward along the simulation chain + +# Arguments + - `node::MCTSNode` + leaf node of a search tree + - `simTrajectoryReward::T` + total reward from trajectory simulation + +# Return + - `No return` + +# Example +```jldoctest +julia> +``` + +# Signature +""" +function backpropagate(node::MCTSNode, simTrajectoryReward::T; + discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} + while !isroot(node) + # Update the statistics of the current node based on the result of the playout + node.visits += 1 + node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits + simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain + node = node.parent + end +end + + +""" Determine whether a node is a leaf node of a search tree. + +# Arguments + - `node::MCTSNode` + a search tree node +# Return + - `result::Bool` + true if it is a leaf node, false otherwise. +# Example +```jldoctest +julia> using Revise +julia> using YiemAgent, DataStructures +julia> initialState = Dict{Symbol, Any}( + :customerinfo=> Dict{Symbol, Any}(), + :storeinfo=> Dict{Symbol, Any}(), + + :thoughtHistory=> OrderedDict{Symbol, Any}( + :question=> "How are you?", + ) + ) +julia> statetype = typeof(initialState) +julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}()) +julia> YiemAgent.isleaf(root) +true +``` + +# TODO + [] update docs + +# Signature +""" +isleaf(node::MCTSNode)::Bool = isempty(node.children) + + +""" Determine wheter a given node is a root node + +# Arguments + - `node::MCTSNode` + node of a search tree + +# Return + - `isrootnode::Bool` + true if the given node is root node, false otherwise + +# Example +```jldoctest +julia> +``` + +# Signature +""" +isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false + + + +""" Select child node based on the highest statevalue + +# Arguments + - `node::MCTSNode` + node of a search tree + +# Return + - `childNode::MCTSNode` + the highest value child node + +# Example +```jldoctest +julia> +``` + +# Signature +""" +function selectChildNode(node::MCTSNode)::MCTSNode + highestProgressValue = 0 + nodekey = nothing + + # loop thought node children dictionary to find the highest progress value + for (k, childNode) in node.children + potential = childNode.progressvalue + childNode.reward + if childNode.reward > 0 #XXX for testing. remove when done. + println("") + end + if potential > highestProgressValue + highestProgressValue = potential + nodekey = childNode.nodekey + end + end + + return node.children[nodekey] +end + + +""" Expand selected node + +# Arguments + - `a::T1` + One of YiemAgent's agent + - `node::MCTSNode` + MCTS node + - `state::T2` + a state of a game. Can be a Dict or something else. + - `decisionMaker::Function` + a function that output Thought and Action + - `evaluator::Function` + a function that output trajectory progress score + +# Return + +# Example +```jldoctest +julia> +``` + +# TODO + [] update docstring + [] try loop should limit to 3 times. if not succeed, skip + [] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise. + [x] store feedback -> state -> agent. + + +# Signature +""" +function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function, + reflector::Function, transition::Function; totalsample::Integer=3 + ) where {T1<:AbstractDict} + + nthSample = 0 + while true + nthSample += 1 + if nthSample <= totalsample + thoughtDict = decisionMaker(a, node.state) + println("---> expand() sample $nthSample") + pprintln(node.state[:thoughtHistory]) + pprintln(thoughtDict) + newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict) + + stateevaluation, progressvalue = evaluator(workDict, newstate) + + if newstate[:reward] < 0 + pprint(newstate[:thoughtHistory]) + newstate[:evaluation] = stateevaluation + newstate[:lesson] = reflector(a, newstate) + + # store new lesson for later use + lessonDict = copy(JSON3.read("lesson.json")) + latestLessonKey, latestLessonIndice = + GeneralUtils.findHighestIndexKey(lessonDict, "lesson") + nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1 + newLessonKey = Symbol("lesson_$(nextIndice)") + lessonDict[newLessonKey] = newstate + open("lesson.json", "w") do io + JSON3.pretty(io, lessonDict) + end + print("---> reflector()") + end + + if newNodeKey ∉ keys(node.children) + node.children[newNodeKey] = + MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], + newstate[:isterminal], node, Dict{String, MCTSNode}()) + end + else + break + end + end +end + + +""" Get a new state + +# Arguments + - `a::T1` + one of YiemAgent's agent + - `state::T2` + current game state + - `thoughtDict::T3` + contain Thought, Action, Observation + - `isterminal::Function` + a function to determine terminal state + +# Return + - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}` + +# Example +```jldoctest +julia> state = Dict{Symbol, Dict{Symbol, Any}}( + :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."), + :storeinfo => Dict(), + :customerinfo => Dict() + ) +julia> thoughtDict = Dict( + :question=> "I want to buy a bottle of wine.", + :thought_1=> "The customer wants to buy a bottle of wine.", + :action_1=> Dict{Symbol, Any}( + :name=>"Chatbox", + :input=>"What occasion are you buying the wine for?", + ), + :observation_1 => "" + ) +``` + +# TODO + - [] add other actions + - [WORKING] add embedding of newstate and store in newstate[:embedding] + +# Signature +""" +function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2 + )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict} + + # actionname = thoughtDict[:action][:name] + # actioninput = thoughtDict[:action][:input] + + # # map action and input() to llm function + # response, select, reward, isterminal = + # if actionname == "chatbox" + # # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean + # # so that other simulation start from this same node is not contaminated with actioninput + # virtualWineUserChatbox(workDict, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer + # elseif actionname == "winestock" + # winestock(a, actioninput) + # elseif actionname == "recommendbox" + # virtualWineUserRecommendbox(workDict, actioninput) + # else + # error("undefined LLM function. Requesting $actionname") + # end + + # newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal) + # if actionname == "chatbox" + # push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) ) + # push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response)) + # end + + return (newNodeKey, newstate) +end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +end # module mcts \ No newline at end of file diff --git a/src/type.jl b/src/type.jl new file mode 100644 index 0000000..bb6cdce --- /dev/null +++ b/src/type.jl @@ -0,0 +1,116 @@ +module type + +export MCTSNode + +# ---------------------------------------------- 100 --------------------------------------------- # + + +""" a node for MCTS search tree + +# Arguments + - `state::T` + a state of a game. Can be a Dict or something else. + - `visits::Integer ` + number of time the game visits this state + - `stateValue::Float64` + state value + - `children::Dict{T, MCTSNode}` + children node + +# Return + - `nothing` +# Example +```jldoctest +julia> state = Dict( + :info=> Dict(), # keyword info + :thoughtHistory=> Dict( + :question=> _, + :thought_1=> _, + :action_1=> _, + :observation_1=> _, + :thought_2=> _, + ... + ) + ) +``` + +# TODO + [] update docstring + +# Signature +""" +mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} + nodekey::T2 + state::T1 + visits::Integer + progressvalue::Number # estimate value by LLM's reasoning + statevalue::Number # store discounted commulative reward (gather from its child node) + reward::Number # this node's own reward + isterminal::Bool + parent::Union{MCTSNode, Nothing} + children::Dict{String, MCTSNode} +end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +end # module type \ No newline at end of file diff --git a/src/util.jl b/src/util.jl new file mode 100644 index 0000000..559af39 --- /dev/null +++ b/src/util.jl @@ -0,0 +1,139 @@ +module util + +export UCTselect + +using ..type + +# ---------------------------------------------- 100 --------------------------------------------- # + +""" Select a node based on UCT score + +# Arguments + - `node::MCTSNode` + mcts node + - `w::T` + exploration weight. Value is usually between 1 to 2. + Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%. + Value 2.0 makes MCTS aggressively search the tree. +# Return + - `selectedNode::MCTSNode` + +# Example +```jldoctest +julia> +``` + +# Signature +""" +function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} + maxUCT = -Inf + selectedNode = nothing + + for (childState, childNode) in node.children + UCTvalue = + if childNode.visits != 0 + weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term + childNode.statevalue + weightedterm + else # node.visits == 0 makes sqrt() in explore term error + childNode.progressvalue # exploit term + end + + if UCTvalue > maxUCT + maxUCT = UCTvalue + selectedNode = childNode + end + end + + return selectedNode +end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +end # module util \ No newline at end of file