From 62c6ce90ed44f66ba78fcb0c56537fd8349dbae0 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Mon, 13 May 2024 17:37:44 +0700 Subject: [PATCH] update --- src/interface.jl | 198 ++++++++++++++++++++++++++++------ src/llmfunction.jl | 59 +++++++++-- src/mcts.jl | 259 +++++++++++++++++++++++++++++++++++++-------- test/test_1.jl | 2 +- 4 files changed, 432 insertions(+), 86 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index b280b8c..8eba29b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -295,7 +295,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T } {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question. It is also better to have simple searches corresponding to a single entity, making this the best action.", - "score": 7 + "score": 10 } { @@ -309,7 +309,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T } {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it, not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.", - "score": 3 + "score": 0 } Let's begin!: @@ -378,6 +378,143 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T end +# """ + +# # Arguments +# - `a::T1` +# one of Yiem's agent +# - `state::T2` +# a game state + +# # Return +# - `evaluation::Tuple{String, Integer}` +# evaluation and score + +# # Example +# ```jldoctest +# julia> +# ``` + +# # TODO +# - [] update docs +# - [] implement the function + +# # Signature +# """ +# function comparer(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict} + +# _prompt = +# """ +# Analyze the trajectories of a solution to a question answering task. The trajectories are +# labeled by environmental observations about the situation, thoughts that can reason about +# the current situation and actions that can be three types: +# 1) winestock[query], which you can use to find wine in your inventory. +# 2) chatbox[text], which you can use to interact with the user. +# 3) recommendbox[answer], which returns your wine recommendation to the user. + +# Given a question and a trajectory, evaluate its correctness and provide your reasoning and +# analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories +# can be correct if the thoughts and actions so far are correct, even if the answer is not found +# yet. Do not generate additional thoughts or actions. Then ending with the correctness score s +# where s is an integer from 0 to 10. + +# You should only respond in JSON format as describe below: +# {"evaluation": "your evaluation", "score": "your evaluation score"} + +# Here are some examples: +# { +# "question": "I'm looking for a sedan with an automatic driving feature.", +# "thought_1": "I have many types of sedans in my inventory, each with diverse features.", +# "thought_2": "But there is only 1 model that has the feature customer wanted.", +# "thought_3": "I should check our inventory first to see if we have it.", +# "action_1": {"name": "inventory", "input": "Yiem model A"}, +# "observation_1": "Yiem model A is in stock." +# } +# {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question. +# It is also better to have simple searches corresponding to a single entity, making this the best action.", +# "score": 10 +# } + +# { +# "question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?", +# "thought_1": "Let me check our inventory first to see if I have it.", +# "action_1": {"name": "inventory", "input": "pen with 4 color and a pencil."}, +# "observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}", +# "thought_2": "Ok, I have what the user is asking. Let's tell the user.", +# "action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"}, +# "observation_1": "This is not what I wanted." +# } +# {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it, +# not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.", +# "score": 0 +# } + +# Let's begin!: +# $(JSON3.write(state[:thoughtHistory])) +# {"evaluation" +# """ + +# # apply LLM specific instruct format +# externalService = a.config[:externalservice][:text2textinstruct] +# llminfo = externalService[:llminfo] +# prompt = +# if llminfo[:name] == "llama3instruct" +# formatLLMtext_llama3instruct("system", _prompt) +# else +# error("llm model name is not defied yet $(@__LINE__)") +# end + +# msgMeta = GeneralUtils.generate_msgMeta( +# a.config[:externalservice][:text2textinstruct][:mqtttopic], +# senderName= "evaluator", +# senderId= a.id, +# receiverName= "text2textinstruct", +# mqttBroker= a.config[:mqttServerInfo][:broker], +# mqttBrokerPort= a.config[:mqttServerInfo][:port], +# ) + +# outgoingMsg = Dict( +# :msgMeta=> msgMeta, +# :payload=> Dict( +# :text=> prompt, +# :kwargs=> Dict( +# :max_tokens=> 512, +# :stop=> ["<|eot_id|>"], +# ) +# ) +# ) + +# for attempt in 1:5 +# try +# response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) +# _responseJsonStr = response[:response][:text] +# expectedJsonExample = +# """ +# Here is an expected JSON format: +# {"evaluation": "...", "score": "..."} +# """ +# responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) +# evaluationDict = copy(JSON3.read(responseJsonStr)) + +# # check if dict has all required value +# dummya::AbstractString = evaluationDict[:evaluation] +# dummyb::Integer = evaluationDict[:score] + +# return (evaluationDict[:evaluation], evaluationDict[:score]) +# catch e +# io = IOBuffer() +# showerror(io, e) +# errorMsg = String(take!(io)) +# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) +# println("") +# @warn "Attempt $attempt. Error occurred: $errorMsg\n$st" +# println("") +# end +# end +# error("evaluator failed to generate an evaluation") +# end + + """ # Arguments @@ -600,8 +737,9 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?") # TODO - [] update docstring - - [WORKING] MCTS() for planning + - [x] MCTS() for planning - [] add recap to initialState for earlier completed question + - [WORKING] conversation loop # Signature """ @@ -617,36 +755,36 @@ function conversation(a::T, userinput::Dict) where {T<:agent} # add usermsg to a.chathistory addNewMessage(a, "user", userinput[:text]) - #[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned - if !isempty(a.plan[:currenttrajectory]) && - a.plan[:currenttrajectory][end][:action] == "chatbox" - - - - + currentstate = + if isempty(a.plan[:currenttrajectory]) + # set up initial state + Dict{Symbol, Any}( + # deepcopy the info to prevent modifying the info unintentionally during MCTS planning + :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), + :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), + :userselect=> nothing, + :reward=> 0, + :isterminal=> false, + :evaluation=> nothing, + :lesson=> nothing, + :thoughtDict=> nothing, + :totalTrajectoryReward=> nothing, + :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... + # :recap=>, + :question=> userinput[:text], + ) + ) else - initialState = Dict{Symbol, Any}( + a.plan[:currenttrajectory] + end - # deepcopy the info to prevent modifying the info unintentionally during MCTS planning - :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), - :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), - :select=> nothing, - :reward=> 0, - :isterminal=> false, - :evaluation=> nothing, - :lesson=> nothing, - :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... - # :recap=>, - :question=> userinput[:text], - ) - ) - bestplan = runMCTS(a, initialState, decisionMaker, evaluator, reflector, - 2, 3, 4, 1.0) - error("---> bestplan") + bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector, + totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0) + + # transition + newstate = transition(a, bestNextState) + a.plan[:currenttrajectory] = newstate - # actor loop(bestplan) - - end end end diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 7b9bfd3..523244a 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -1,7 +1,7 @@ module llmfunction -export virtualWineCustomerChatbox, jsoncorrection, winestock, - virtualWineCustomerReccommendbox +export virtualWineUserChatbox, jsoncorrection, winestock, + virtualWineUserRecommendbox, userChatbox, userRecommendbox using HTTP, JSON3, URIs, Random using GeneralUtils @@ -26,8 +26,46 @@ julia> # Signature """ -function chatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} - error("--> chatbox") +function userChatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} + error("--> userChatbox") + + # put in model format + virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] + llminfo = virtualWineCustomer[:llminfo] + formattedinput = + if llminfo[:name] == "llama3instruct" + formatLLMtext_llama3instruct("assistant", input) + else + error("llm model name is not defied yet $(@__LINE__)") + end + + # send formatted input to user using GeneralUtils.sendReceiveMqttMsg + + + # return response + +end + + +""" + +# Arguments + +# Return + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docstring + - [PENDING] implement the function + +# Signature +""" +function userRecommendbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} + error("--> userRecommendbox") # put in model format virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] @@ -69,7 +107,7 @@ julia> # Signature """ -function virtualWineCustomerReccommendbox(a::T1, input +function virtualWineUserRecommendbox(a::T1, input )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent} # put in model format @@ -85,7 +123,7 @@ function virtualWineCustomerReccommendbox(a::T1, input # send formatted input to user using GeneralUtils.sendReceiveMqttMsg msgMeta = GeneralUtils.generate_msgMeta( virtualWineCustomer[:mqtttopic], - senderName= "virtualWineCustomerReccommendbox", + senderName= "virtualWineUserRecommendbox", senderId= a.id, receiverName= "virtualWineCustomer", mqttBroker= a.config[:mqttServerInfo][:broker], @@ -125,12 +163,11 @@ julia> ``` # TODO - - [] update docs - - [] add to remove <<< user option select >>> and <<| reward |>> + - [] update docs # Signature """ -function virtualWineCustomerChatbox(a::T1, input::T2 +function virtualWineUserChatbox(a::T1, input::T2 )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} # put in model format @@ -146,7 +183,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2 # send formatted input to user using GeneralUtils.sendReceiveMqttMsg msgMeta = GeneralUtils.generate_msgMeta( virtualWineCustomer[:mqtttopic], - senderName= "virtualWineCustomerChatbox", + senderName= "virtualWineUserChatbox", senderId= a.id, receiverName= "virtualWineCustomer", mqttBroker= a.config[:mqttServerInfo][:broker], @@ -178,7 +215,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2 println("") end end - error("virtualWineCustomerChatbox failed to get a response") + error("virtualWineUserChatbox failed to get a response") end diff --git a/src/mcts.jl b/src/mcts.jl index 30d120f..c9e2df1 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -5,7 +5,8 @@ module mcts -export MCTSNode, runMCTS, isleaf +export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition, + userChatbox using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting using GeneralUtils @@ -51,9 +52,9 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} nodekey::T2 state::T1 visits::Integer - progressvalue::Number - statevalue::Number - reward::Number + progressvalue::Number # estimate value by LLM's reasoning + statevalue::Number # store discounted commulative reward (gather from its child node) + reward::Number # this node's own reward isterminal::Bool parent::Union{MCTSNode, Nothing} children::Dict{String, MCTSNode} @@ -132,23 +133,24 @@ julia> # Signature """ function expand(a::T1, node::MCTSNode, decisionMaker::Function, - evaluator::Function, reflector::Function; n::Integer=3) where {T1<:agent} + evaluator::Function, reflector::Function; totalsample::Integer=3 + ) where {T1<:agent} nthSample = 0 while true nthSample += 1 - if nthSample <= n + if nthSample <= totalsample thoughtDict = decisionMaker(a, node.state) println("---> expand() sample $nthSample") pprintln(node.state[:thoughtHistory]) pprintln(thoughtDict) - newNodeKey, newstate, reward, isterminalstate = - MCTStransition(a, node.state, thoughtDict) + node.state[:thoughtDict] = thoughtDict + newNodeKey, newstate = MCTStransition(a, node.state) # add evaluator stateevaluation, progressvalue = evaluator(a, newstate) - if reward < 0 + if newstate[:reward] < 0 pprint(newstate[:thoughtHistory]) newstate[:evaluation] = stateevaluation newstate[:lesson] = reflector(a, newstate) @@ -167,8 +169,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, end if newNodeKey ∉ keys(node.children) - node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, - reward, isterminalstate, node, Dict{String, MCTSNode}()) + node.children[newNodeKey] = + MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], + newstate[:isterminal], node, Dict{String, MCTSNode}()) end else break @@ -196,24 +199,30 @@ end julia> ``` +# TODO + - [] update docs + # Signature """ function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, - reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent} + reflector::Function; maxDepth::Integer=3, totalsample::Integer=3 + )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:agent} simTrajectoryReward = 0.0 + terminalstate = nothing for depth in 1:maxDepth simTrajectoryReward += node.reward if node.isterminal + terminalstate = node.state break else - expand(a, node, decisionMaker, evaluator, reflector; n=n) + expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample) node = selectChildNode(node) end end - return simTrajectoryReward + return (simTrajectoryReward, terminalstate) end """ Backpropagate reward along the simulation chain @@ -285,20 +294,21 @@ julia> thoughtDict = Dict( # Signature """ -function MCTStransition(a::T1, state::T2, thoughtDict::T3 - )::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} +function MCTStransition(a::T1, state::T2 + )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict} + thoughtDict = state[:thoughtDict] actionname = thoughtDict[:action][:name] actioninput = thoughtDict[:action][:input] # map action and input() to llm function response, select, reward, isterminal = if actionname == "chatbox" - virtualWineCustomerChatbox(a, actioninput) # virtual customer + virtualWineUserChatbox(a, actioninput) # virtual customer elseif actionname == "winestock" winestock(a, actioninput) elseif actionname == "recommendbox" - virtualWineCustomerReccommendbox(a, actioninput) + virtualWineUserRecommendbox(a, actioninput) else error("undefined LLM function. Requesting $actionname") end @@ -321,7 +331,85 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3 newNodeKey = GeneralUtils.uuid4snakecase() - return (newNodeKey, newstate, reward, isterminal) + return (newNodeKey, newstate) +end + + +""" Get a new state + +# Arguments + - `a::T1` + one of YiemAgent's agent + - `state::T2` + current game state + - `thoughtDict::T3` + contain Thought, Action, Observation + - `isterminal::Function` + a function to determine terminal state + +# Return + - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}` + +# Example +```jldoctest +julia> state = Dict{Symbol, Dict{Symbol, Any}}( + :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."), + :storeinfo => Dict(), + :customerinfo => Dict() + ) +julia> thoughtDict = Dict( + :question=> "I want to buy a bottle of wine.", + :thought_1=> "The customer wants to buy a bottle of wine.", + :action_1=> Dict{Symbol, Any}( + :name=>"Chatbox", + :input=>"What occasion are you buying the wine for?", + ), + :observation_1 => "" + ) +``` + +# TODO + - [x] add other actions + - [] add embedding of newstate and store in newstate[:embedding] + +# Signature +""" +function transition(a::T1, state::T2 + )::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict} + + thoughtDict = state[:thoughtDict] + actionname = thoughtDict[:action][:name] + actioninput = thoughtDict[:action][:input] + + # map action and input() to llm function + response, select, reward, isterminal = + if actionname == "chatbox" + userChatbox(a, actioninput) # virtual customer + elseif actionname == "winestock" + winestock(a, actioninput) + elseif actionname == "recommendbox" + userRecommendbox(a, actioninput) + else + error("undefined LLM function. Requesting $actionname") + end + + latestThoughtKey, latestThoughtIndice = + GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought") + nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 + latestThoughtKey = Symbol("thought_$nextIndice") + latestActionKey = Symbol("action_$nextIndice") + + # add Thought, action, observation to thoughtHistory + newstate = deepcopy(state) + newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] + newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] + newObservationKey = Symbol("observation_$(nextIndice)") + newstate[:thoughtHistory][newObservationKey] = response + newstate[:reward] = reward + newstate[:select] = select + newstate[:isterminal] = isterminal + + return newstate end @@ -396,6 +484,90 @@ function selectChildNode(node::MCTSNode)::MCTSNode end + +""" + +# Arguments + - `node::MCTSNode` + node of a search tree + +# Return + - `childNode::MCTSNode` + the highest value child node + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docs + - [TESTING] implement the function + +# Signature +""" +function selectBestNextState(node::MCTSNode)::MCTSNode + highestProgressValue = 0 + nodekey = nothing + + # if all childnode has statevalue == 0, use progressvalue + reward to select the best node + stateValueSum = sum([v.statevalue for (k, v) in node.children]) + + if stateValueSum != 0 + for (k, childnode) in node.children + potential = childnode.statevalue / childnode.visits + + if potential > highestProgressValue + highestProgressValue = potential + nodekey = childnode.nodekey + end + end + else + for (k, childnode) in node.children + potential = childnode.progressvalue + childnode.reward + + if potential > highestProgressValue + highestProgressValue = potential + nodekey = childnode.nodekey + end + end + end + + return node.children[nodekey] +end + + + +""" + +# Arguments + - `node::MCTSNode` + node of a search tree + +# Return + - `childNode::MCTSNode` + the highest value child node + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docs + - [TESTING] implement the function + +# Signature +""" +function selectBestTrajectory(node::MCTSNode)::MCTSNode + while !isleaf(node) + node = selectBestNextState(node) + end + + return node +end + + """ Determine wheter a given node is a root node # Arguments @@ -451,7 +623,7 @@ julia> # TODO [] update docstring - [PENDING] return best plan + [x] return best action # Signature """ @@ -460,46 +632,48 @@ function runMCTS( initialState, decisionMaker::Function, evaluator::Function, - reflector::Function, - n::Integer, - maxDepth::Integer, - maxIterations::Integer, - w::Float64 + reflector::Function; + totalsample::Integer=3, + maxDepth::Integer=3, + maxiterations::Integer=10, + explorationweight::Number=1.0, ) where {T1<:agent} root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) - - for nth in 1:maxIterations + + for nth in 1:maxiterations node = root node.visits += 1 while !isleaf(node) - node = UCTselect(node, w) + node = UCTselect(node, explorationweight) end if node.isterminal # MCTS arrive at the leaf node that is also a terminal state, # do nothing then go directly to backpropagation backpropagate(leafNode, node.reward) else - expand(a, node, decisionMaker, evaluator, reflector; n=n) + expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample) leafNode = selectChildNode(node) - simTrajectoryReward = simulate(a, leafNode, decisionMaker, evaluator, - reflector; maxDepth=maxDepth, n=n) + simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker, evaluator, + reflector; maxDepth=maxDepth, totalsample=totalsample) + if terminalstate !== nothing + terminalstate[:totalTrajectoryReward] = simTrajectoryReward + end + + #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation + # open("trajectory.json", "w") do io + # JSON3.pretty(io, terminalstate) + # end + backpropagate(leafNode, simTrajectoryReward) end end - avgStateValue = 0 - selectedChildKey = nothing - for (k, v) in root.children - k_avgStateValue = v.statevalue / v.visits - if k_avgStateValue > avgStateValue - avgStateValue = k_avgStateValue - selectedChildKey = k - end - end + bestNextState = selectBestNextState(root) + besttrajectory = selectBestTrajectory(root) - return root.children[selectedChildKey] + return (bestNextState.state, besttrajectory.state) end @@ -525,9 +699,6 @@ end - - - diff --git a/test/test_1.jl b/test/test_1.jl index 3e292b5..3e02670 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "I like it dry.", + :text=> "I like dry wine with fruity flavors.", :select=> nothing, :reward=> 0, :isterminal=> false,