diff --git a/src/interface.jl b/src/interface.jl index 0110ff2..4b3e7cc 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -753,6 +753,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} addNewMessage(a, "user", userinput[:text]) if isempty(a.plan[:currenttrajectory]) + a.plan[:currenttrajectory] = Dict{Symbol, Any}( # deepcopy the info to prevent modifying the info unintentionally during MCTS planning :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), @@ -769,7 +770,10 @@ function conversation(a::T, userinput::Dict) where {T<:agent} :thoughtHistory=> OrderedDict{Symbol, Any}( #[] :recap=>, :question=> userinput[:text], - ) + ), + :virtualCustomerChatHistory=> Vector{Dict{Symbol, Any}}( + [Dict(:name=> "user", :text=> userinput[:text])] + ), ) else _, a.plan[:currenttrajectory] = makeNewState(a.plan[:currenttrajectory], @@ -780,7 +784,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} while true bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker, - evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=1, explorationweight=1.0) + evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=2, explorationweight=1.0) a.plan[:activeplan] = bestNextState latestActionKey, latestActionIndice = diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 98df02f..b176daa 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -3,7 +3,7 @@ module llmfunction export virtualWineUserChatbox, jsoncorrection, winestock, virtualWineUserRecommendbox, userChatbox, userRecommendbox -using HTTP, JSON3, URIs, Random +using HTTP, JSON3, URIs, Random, PrettyPrinting using GeneralUtils using ..type, ..util @@ -164,28 +164,119 @@ julia> # TODO - [] update docs + - [x] write a prompt for virtual customer # Signature """ -function virtualWineUserChatbox(a::T1, input::T2 +function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} - # put in model format - virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] - llminfo = virtualWineCustomer[:llminfo] - prompt = - if llminfo[:name] == "llama3instruct" - formatLLMtext_llama3instruct("assistant", input) - else - error("llm model name is not defied yet $(@__LINE__)") + previouswines = + """ + You have the following wines previously: + + """ + + systemmsg = + """ + You find yourself in a well-stocked wine store, engaged in a conversation with the store's knowledgeable sommelier. + You're on a quest to find a bottle of wine that aligns with your specific preferences and requirements. + + The ideal wine you're seeking should meet the following criteria: + 1. It should fit within your budget. + 2. It should be suitable for the occasion you're planning. + 3. It should pair well with the food you intend to serve. + 4. It should be of a particular type of wine you prefer. + 5. It should possess certain characteristics, including: + - The level of sweetness. + - The intensity of its flavor. + - The amount of tannin it contains. + - Its acidity level. + + Here's the criteria details: + { + "budget": 50, + "occasion": "graduation ceremony", + "food pairing": "Thai food", + "type of wine": "red", + "wine sweetness level": "dry", + "wine intensity level": "full-bodied", + "wine tannin level": "low", + "wine acidity level": "medium", + } + + You should only respond with "text", "select", "reward", "isterminal" steps. + "text" is your conversation. + "select" is an integer. Choose an option when presented with choices, or leave it null if none of the options satisfy you or if no choices are available. + "reward" is an integer, it can be three number: + 1) 1 if you find the right wine. + 2) 0 if you don’t find the ideal wine. + 3) -1 if you’re dissatisfied with the sommelier’s response. + "isterminal" can be false if you still want to talk with the sommelier, true otherwise. + + You should only respond in JSON format as describe below: + { + "text": "your conversation", + "select": null, + "reward": 0, + "isterminal": false + } + + Here are some examples: + { + "text": "My budget is 30 USD.", + "select": null, + "reward": 0, + "isterminal": false + } + { + "text": "I like the 2nd option.", + "select": 2, + "reward": 1, + "isterminal": true + } + + Let's begin! + """ + + pushfirst!(virtualCustomerChatHistory, Dict(:name=> "system", :text=> systemmsg)) + + # replace the :user key in chathistory to allow the virtual wine customer AI roleplay + chathistory::Vector{Dict{Symbol, Any}} = Vector{Dict{Symbol, Any}}() + for i in virtualCustomerChatHistory + newdict = Dict() + newdict[:name] = + if i[:name] == "user" + "you" + elseif i[:name] == "assistant" + "sommelier" + else + i[:name] + end + + newdict[:text] = i[:text] + push!(chathistory, newdict) end + push!(chathistory, Dict(:name=> "assistant", :text=> input)) + + # put in model format + prompt = formatLLMtext(chathistory, "llama3instruct") + prompt *= + """ + <|start_header_id|>you<|end_header_id|> + {"text" + """ + + pprint(prompt) + externalService = a.config[:externalservice][:text2textinstruct] + # send formatted input to user using GeneralUtils.sendReceiveMqttMsg msgMeta = GeneralUtils.generate_msgMeta( - virtualWineCustomer[:mqtttopic], + externalService[:mqtttopic], senderName= "virtualWineUserChatbox", senderId= a.id, - receiverName= "virtualWineCustomer", + receiverName= "text2textinstruct", mqttBroker= a.config[:mqttServerInfo][:broker], mqttBrokerPort= a.config[:mqttServerInfo][:port], msgId = "dummyid" #CHANGE remove after testing finished @@ -201,10 +292,33 @@ function virtualWineUserChatbox(a::T1, input::T2 attempt = 0 for attempt in 1:5 try - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) - response = result[:response] + response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) + _responseJsonStr = response[:response][:text] + expectedJsonExample = + """ + Here is an expected JSON format: + { + "text": "...", + "select": "...", + "reward": "...", + "isterminal": "..." + } + """ + responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseDict = copy(JSON3.read(responseJsonStr)) - return (response[:text], response[:select], response[:reward], response[:isterminal]) + text = responseDict[:text] + select = responseDict[:select] == "null" ? nothing : responseDict[:select] + reward = responseDict[:reward] + isterminal = responseDict[:isterminal] + + if text != "" && select != "" && reward != "" && isterminal != "" + # pass test + else + error("virtual customer not answer correctly") + end + + return (text, select, reward, isterminal) catch e io = IOBuffer() showerror(io, e) @@ -218,6 +332,57 @@ function virtualWineUserChatbox(a::T1, input::T2 error("virtualWineUserChatbox failed to get a response") end +# function virtualWineUserChatbox(a::T1, input::T2 +# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} + +# # put in model format +# virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] +# llminfo = virtualWineCustomer[:llminfo] +# prompt = +# if llminfo[:name] == "llama3instruct" +# formatLLMtext_llama3instruct("assistant", input) +# else +# error("llm model name is not defied yet $(@__LINE__)") +# end + +# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg +# msgMeta = GeneralUtils.generate_msgMeta( +# virtualWineCustomer[:mqtttopic], +# senderName= "virtualWineUserChatbox", +# senderId= a.id, +# receiverName= "virtualWineCustomer", +# mqttBroker= a.config[:mqttServerInfo][:broker], +# mqttBrokerPort= a.config[:mqttServerInfo][:port], +# msgId = "dummyid" #CHANGE remove after testing finished +# ) + +# outgoingMsg = Dict( +# :msgMeta=> msgMeta, +# :payload=> Dict( +# :text=> prompt, +# ) +# ) + +# attempt = 0 +# for attempt in 1:5 +# try +# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) +# response = result[:response] + +# return (response[:text], response[:select], response[:reward], response[:isterminal]) +# catch e +# io = IOBuffer() +# showerror(io, e) +# errorMsg = String(take!(io)) +# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) +# println("") +# @warn "Error occurred: $errorMsg\n$st" +# println("") +# end +# end +# error("virtualWineUserChatbox failed to get a response") +# end + """ Search wine in stock. @@ -239,7 +404,7 @@ julia> result = winestock(agent, input) # TODO [] update docs - [PENDING] implement the function + [WORKING] implement the function # Signature """ @@ -302,8 +467,8 @@ function jsoncorrection(a::T1, input::T2, _prompt = """ Your goal are: - 1) Use the info why the given JSON string failed to load and provide a corrected version that can be loaded by Python's json.load function. - 2) The user need Corrected JSON string only. Do not provide any other info. + 1) Use the expected JSON format as a guideline to check why the given JSON string failed to load and provide a corrected version that can be loaded by Python's json.load function. + 2) Provide Corrected JSON string only. Do not provide any other info. $correctJsonExample diff --git a/src/mcts.jl b/src/mcts.jl index 5632351..536fc33 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -146,7 +146,6 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, pprintln(thoughtDict) newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) - stateevaluation, progressvalue = evaluator(a, newstate) if newstate[:reward] < 0 @@ -302,7 +301,9 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2 # map action and input() to llm function response, select, reward, isterminal = if actionname == "chatbox" - virtualWineUserChatbox(a, actioninput) # virtual customer + # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean + # so that other simulation start from this same node is not contaminated with actioninput + virtualWineUserChatbox(a, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer elseif actionname == "winestock" winestock(a, actioninput) elseif actionname == "recommendbox" @@ -311,7 +312,13 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2 error("undefined LLM function. Requesting $actionname") end - return makeNewState(state, thoughtDict, response, select, reward, isterminal) + newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal) + if actionname == "chatbox" + push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) ) + push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response)) + end + + return (newNodeKey, newstate) end @@ -386,7 +393,7 @@ julia> # TODO - [] update docstring - - [TESTING] implement the function + - [x] implement the function # Signature """ @@ -520,7 +527,7 @@ julia> # TODO - [] update docs - - [TESTING] implement the function + - [x] implement the function # Signature """ @@ -573,7 +580,7 @@ julia> # TODO - [] update docs - - [TESTING] implement the function + - [x] implement the function # Signature """ @@ -675,7 +682,7 @@ function runMCTS( leafNode = selectChildNode(node) simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker, evaluator, reflector; maxDepth=maxDepth, totalsample=totalsample) - if terminalstate !== nothing + if terminalstate !== nothing #XXX not sure why I need this terminalstate[:totalTrajectoryReward] = simTrajectoryReward end diff --git a/src/util.jl b/src/util.jl index 80138c8..fa28a7e 100644 --- a/src/util.jl +++ b/src/util.jl @@ -201,13 +201,13 @@ function formatLLMtext_llama3instruct(name::T, text::T) where {T<:AbstractString <|begin_of_text|> <|start_header_id|>$name<|end_header_id|> $text - <|eot_id|>\n + <|eot_id|> """ else """ <|start_header_id|>$name<|end_header_id|> $text - <|eot_id|>\n + <|eot_id|> """ end @@ -286,7 +286,7 @@ end TODO\n ----- [] update docstring - [TESTING] implement the function + [PENDING] implement the function Signature\n ----- diff --git a/test/runtest.jl b/test/runtest.jl index dd617e9..3e4fc6a 100644 --- a/test/runtest.jl +++ b/test/runtest.jl @@ -51,8 +51,6 @@ tools=Dict( # update input format ) a = YiemAgent.sommelier( - receiveUserMsgChannel, - receiveInternalMsgChannel, agentConfig, name="assistant", id="testingSessionID", # agent instance id @@ -68,7 +66,7 @@ response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a b ) ) println("---> YiemAgent: ", response) -response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening", +response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.", :select=> nothing, :reward=> 0, :isterminal=> false,