diff --git a/Manifest.toml b/Manifest.toml index ee06871..40d6d1f 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -266,15 +266,15 @@ version = "1.0.3" [[deps.LoweredCodeUtils]] deps = ["JuliaInterpreter"] -git-tree-sha1 = "31e27f0b0bf0df3e3e951bfcc43fe8c730a219f6" +git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426" uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "2.4.5" +version = "2.4.6" [[deps.MQTTClient]] deps = ["Distributed", "Random", "Sockets"] -git-tree-sha1 = "7d6a1042b8c330d20e4dfbd941f510f92b457624" +git-tree-sha1 = "c58ba9d6ae121f58494fa1e5164213f5b4e3e2c7" uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" -version = "0.2.1" +version = "0.3.0" weakdeps = ["PrecompileTools"] [deps.MQTTClient.extensions] @@ -408,9 +408,9 @@ deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[deps.PtrArrays]] -git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913" +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.1.0" +version = "1.2.0" [[deps.PythonCall]] deps = ["CondaPkg", "Dates", "Libdl", "MacroTools", "Markdown", "Pkg", "REPL", "Requires", "Serialization", "Tables", "UnsafePointers"] @@ -456,10 +456,10 @@ uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" version = "0.7.1" [[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" +version = "0.4.2+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" diff --git a/src/interface.jl b/src/interface.jl index 4b3e7cc..130b43b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -194,9 +194,9 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 :kwargs=> Dict( :max_tokens=> 512, :stop=> ["<|eot_id|>"], - ) ) ) + ) @show outgoingMsg for attempt in 1:5 diff --git a/src/llmfunction.jl b/src/llmfunction.jl index b176daa..86fdb40 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -223,12 +223,18 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory } Here are some examples: + + sommelier: "What's your budget? + you: { "text": "My budget is 30 USD.", "select": null, "reward": 0, "isterminal": false } + + sommelier: "The first option is Zena Crown and the second one is Buano Red." + you: { "text": "I like the 2nd option.", "select": 2, @@ -307,12 +313,12 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseDict = copy(JSON3.read(responseJsonStr)) - text = responseDict[:text] - select = responseDict[:select] == "null" ? nothing : responseDict[:select] - reward = responseDict[:reward] - isterminal = responseDict[:isterminal] + text::AbstractString = responseDict[:text] + select::Union{Nothing, Number} = responseDict[:select] == "null" ? nothing : responseDict[:select] + reward::Number = responseDict[:reward] + isterminal::Bool = responseDict[:isterminal] - if text != "" && select != "" && reward != "" && isterminal != "" + if text != "" # pass test else error("virtual customer not answer correctly") @@ -332,58 +338,6 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory error("virtualWineUserChatbox failed to get a response") end -# function virtualWineUserChatbox(a::T1, input::T2 -# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} - -# # put in model format -# virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] -# llminfo = virtualWineCustomer[:llminfo] -# prompt = -# if llminfo[:name] == "llama3instruct" -# formatLLMtext_llama3instruct("assistant", input) -# else -# error("llm model name is not defied yet $(@__LINE__)") -# end - -# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg -# msgMeta = GeneralUtils.generate_msgMeta( -# virtualWineCustomer[:mqtttopic], -# senderName= "virtualWineUserChatbox", -# senderId= a.id, -# receiverName= "virtualWineCustomer", -# mqttBroker= a.config[:mqttServerInfo][:broker], -# mqttBrokerPort= a.config[:mqttServerInfo][:port], -# msgId = "dummyid" #CHANGE remove after testing finished -# ) - -# outgoingMsg = Dict( -# :msgMeta=> msgMeta, -# :payload=> Dict( -# :text=> prompt, -# ) -# ) - -# attempt = 0 -# for attempt in 1:5 -# try -# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) -# response = result[:response] - -# return (response[:text], response[:select], response[:reward], response[:isterminal]) -# catch e -# io = IOBuffer() -# showerror(io, e) -# errorMsg = String(take!(io)) -# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) -# println("") -# @warn "Error occurred: $errorMsg\n$st" -# println("") -# end -# end -# error("virtualWineUserChatbox failed to get a response") -# end - - """ Search wine in stock. # Arguments @@ -411,6 +365,150 @@ julia> result = winestock(agent, input) function winestock(a::T1, input::T2 )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} + systemmsg = + """ + As an attentive sommelier, your mission is to determine the user's preferred levels of sweetness, intensity, tannin, and acidity for a wine based on their input. + You'll achieve this by referring to the provided conversion table. + + Conversion Table: + Intensity level: + Level 1: May correspond to "light-bodied" or a similar description. + Level 2: May correspond to "med-light" or a similar description. + Level 3: May correspond to "medium" or a similar description. + Level 4: May correspond to "med-full" or a similar description. + Level 5: May correspond to "full" or a similar description. + Sweetness level: + Level 1: May correspond to "dry", "no-sweet" or a similar description. + Level 2: May correspond to "off-dry", "less-sweet" or a similar description. + Level 3: May correspond to "semi-sweet" or a similar description. + Level 4: May correspond to "sweet" or a similar description. + Level 5: May correspond to "very sweet" or a similar description. + Tannin level: + Level 1: May correspond to "low tannin" or a similar description. + Level 2: May correspond to "semi-low tannin" or a similar description. + Level 3: May correspond to "medium tannin" or a similar description. + Level 4: May correspond to "semi-high tannin" or a similar description. + Level 5: May correspond to "high tannin" or a similar description. + Acidity level: + Level 1: May correspond to "low acidity" or a similar description. + Level 2: May correspond to "semi-low acidity" or a similar description. + Level 3: May correspond to "medium acidity" or a similar description. + Level 4: May correspond to "semi-high acidity" or a similar description. + Level 5: May correspond to "high acidity" or a similar description. + + You should only respond in JSON format as describe below: + { + "sweetness": "sweetness level", + "acidity": "acidity level", + "tannin": "tannin level", + "intensity": "intensity level" + } + + Here are some examples: + + user: red wines, price < 50, body=full-bodied, tannins=1, off dry, acidity=medium, intensity=intense, Thai dishes + assistant: + { + "wine_attributes": + { + "sweetness": 2, + "acidity": 3, + "tannin": 1, + "intensity": 5 + } + } + + Let's begin! + """ + + usermsg = + """ + $input + """ + + chathistory = + [ + Dict(:name=> "system", :text=> systemmsg), + Dict(:name=> "user", :text=> usermsg) + ] + + # put in model format + prompt = formatLLMtext(chathistory, "llama3instruct") + prompt *= + """ + <|start_header_id|>assistant<|end_header_id|> + { + """ + + pprint(prompt) + externalService = a.config[:externalservice][:text2textinstruct] + + # send formatted input to user using GeneralUtils.sendReceiveMqttMsg + msgMeta = GeneralUtils.generate_msgMeta( + externalService[:mqtttopic], + senderName= "virtualWineUserChatbox", + senderId= a.id, + receiverName= "text2textinstruct", + mqttBroker= a.config[:mqttServerInfo][:broker], + mqttBrokerPort= a.config[:mqttServerInfo][:port], + msgId = "dummyid" #CHANGE remove after testing finished + ) + + outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> prompt, + ) + ) + + attempt = 0 + for attempt in 1:5 + try + response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) + _responseJsonStr = response[:response][:text] + expectedJsonExample = + """ + Here is an expected JSON format: + { + "wine_attributes": + { + "...": "...", + "...": "...", + } + } + """ + responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseDict = copy(JSON3.read(responseJsonStr)) + + return (text, select, reward, isterminal) + catch e + io = IOBuffer() + showerror(io, e) + errorMsg = String(take!(io)) + st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) + println("") + @warn "Error occurred: $errorMsg\n$st" + println("") + end + end + error("virtualWineUserChatbox failed to get a response") + + + + + + + + + + + + + + + + + winesStr = """ 1: El Enemigo Cabernet Franc 2019 @@ -425,6 +523,23 @@ function winestock(a::T1, input::T2 """ return result, nothing, 0, false end +# function winestock(a::T1, input::T2 +# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} + +# winesStr = +# """ +# 1: El Enemigo Cabernet Franc 2019 +# 2: Tantara Chardonnay 2017 +# """ +# result = +# """ +# I found the following wines in our stock: +# { +# $winesStr +# } +# """ +# return result, nothing, 0, false +# end """ Attemp to correct LLM response's incorrect JSON response. @@ -446,13 +561,14 @@ julia> # Signature """ -function jsoncorrection(a::T1, input::T2, - correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} +function jsoncorrection(a::T1, input::T2, correctJsonExample::T3; + maxattempt::Integer=3 + ) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} incorrectjson = deepcopy(input) correctjson = nothing - for attempt in 1:5 + for attempt in 1:maxattempt try d = copy(JSON3.read(incorrectjson)) correctjson = incorrectjson diff --git a/src/mcts copy.jl b/src/mcts copy.jl deleted file mode 100644 index df65a5c..0000000 --- a/src/mcts copy.jl +++ /dev/null @@ -1,138 +0,0 @@ -""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence - Bound for Trees) selection function, you can follow the steps below: Define the necessary types - and functions for the MCTS algorithm: -""" - -module MCTS - -# export - -using Dates, UUIDs, DataStructures, JSON3, Random -using GeneralUtils - -# ---------------------------------------------- 100 --------------------------------------------- # - -""" - TODO\n - [] update docstring -""" -struct MCTSNode{T} - state::T - visits::Int - total_reward::Float64 - children::Dict{T, MCTSNode} -end - -""" - TODO\n - [] update docstring -""" -function select(node::MCTSNode, c::Float64) - max_uct = -Inf - selected_node = nothing - - for (child_state, child_node) in node.children - uct_value = child_node.total_reward / child_node.visits + - c * sqrt(log(node.visits) / child_node.visits) - if uct_value > max_uct - max_uct = uct_value - selected_node = child_node - end - end - - return selected_node -end - -""" - TODO\n - [] update docstring -""" -function expand(node::MCTSNode, state::T, actions::Vector{T}) - for action in actions - new_state = transition(node.state, action) # Implement your transition function - if new_state ∉ keys(node.children) - node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}()) - end - end -end - -""" - TODO\n - [] update docstring -""" -function simulate(state::T, max_depth::Int) - total_reward = 0.0 - for _ in 1:max_depth - action = select_action(state) # Implement your action selection function - state, reward = transition(state, action) # Implement your transition function - total_reward += reward - end - return total_reward -end - -""" - TODO\n - [] update docstring -""" -function backpropagate(node::MCTSNode, reward::Float64) - node.visits += 1 - node.total_reward += reward - if !isempty(node.children) - best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) - backpropagate(node.children[best_child], -reward) - end -end - -""" - TODO\n - [] update docstring - [] implement transition() -""" -function transition(state, action) - -end - -# ------------------------------------------------------------------------------------------------ # -# Create a complete example using the defined MCTS functions # -# ------------------------------------------------------------------------------------------------ # -""" - TODO\n - [] update docstring -""" -function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64) - root = MCTSNode(initial_state, 0, 0.0, Dict()) - - for _ in 1:max_iterations - node = root - while !is_leaf(node) - node = select(node, w) - end - - expand(node, node.state, actions) - - leaf_node = node.children[node.state] - reward = simulate(leaf_node.state, max_depth) - backpropagate(leaf_node, reward) - end - - best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) - return best_child_state -end - -# Define your transition function and action selection function here - -# Example usage -initial_state = 0 -actions = [-1, 0, 1] -best_action = run_mcts(initial_state, actions, 1000, 10, 1.0) -println("Best action to take: ", best_action) - - - - - - - - - -end \ No newline at end of file diff --git a/src/mcts.jl b/src/mcts.jl index 536fc33..ed7fc1a 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -1,6 +1,4 @@ -""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence - Bound for Trees) selection function, you can follow the steps below: Define the necessary types - and functions for the MCTS algorithm: +""" https://www.harrycodes.com/blog/monte-carlo-tree-search """ module mcts diff --git a/src/type copy.jl b/src/type copy.jl deleted file mode 100644 index faf5b52..0000000 --- a/src/type copy.jl +++ /dev/null @@ -1,226 +0,0 @@ -module type - -export agent, sommelier - -using Dates, UUIDs, DataStructures, JSON3 -using GeneralUtils - -# ---------------------------------------------- 100 --------------------------------------------- # - -abstract type agent end - - -""" A sommelier agent. - -# Arguments - - `mqttClient::Client` - MQTTClient's client - - `msgMeta::Dict{Symbol, Any}` - A dict contain info about a message. - - `config::Dict{Symbol, Any}` - Config info for an agent. Contain mqtt topic for internal use and other info. - -# Keyword Arguments - - `name::String` - Agent's name - - `id::String` - Agent's ID - - `tools::Dict{Symbol, Any}` - Agent's tools - - `maxHistoryMsg::Integer` - max history message - -# Return - - `nothing` - -# Example -```jldoctest -julia> using YiemAgent, MQTTClient, GeneralUtils -julia> msgMeta = GeneralUtils.generate_msgMeta( - "N/A", - replyTopic = "/testtopic/prompt" - ) -julia> tools= Dict( - :chatbox=>Dict( - :name => "chatbox", - :description => "Useful only for when you need to ask the user for more info or context. Do not ask the user their own question.", - :input => "Input should be a text.", - :output => "" , - :func => nothing, - ), - ) -julia> agentConfig = Dict( - :receiveprompt=>Dict( - :mqtttopic=> "/testtopic/prompt", # topic to receive prompt i.e. frontend send msg to this topic - ), - :receiveinternal=>Dict( - :mqtttopic=> "/testtopic/internal", # receive topic for model's internal - ), - :text2text=>Dict( - :mqtttopic=> "/text2text/receive", - ), - ) -julia> client, connection = MakeConnection("test.mosquitto.org", 1883) -julia> agent = YiemAgent.bsommelier( - client, - msgMeta, - agentConfig, - name= "assistant", - id= "555", # agent instance id - tools=tools, - ) -``` - -# TODO -- [] update docstring -- [x] implement the function - -# Signature -""" -@kwdef mutable struct sommelier <: agent - name::String # agent name - id::String # agent id - config::Dict # agent config - tools::Dict - thinkinglimit::Integer # thinking round limit - thinkingcount::Integer # used to count attempted round of a task - - """ Memory - Ref: Chat prompt format https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/discussions/3 - NO "system" message in chathistory because I want to add it at the inference time - chathistory= [ - Dict(:name=>"user", :text=> "Wassup!", :timestamp=> Dates.now()), - Dict(:name=>"assistant", :text=> "Hi I'm your assistant.", :timestamp=> Dates.now()), - ] - - """ - chathistory::Vector{Dict{Symbol, Any}} = Vector{Dict{Symbol, Any}}() - - maxHistoryMsg::Integer # 21th and earlier messages will get summarized - keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}( - :customerinfo => Dict{Symbol, Any}(), - :storeinfo => Dict{Symbol, Any}(), - ) - mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}() - - # 1-historyPoint is in Dict{Symbol, Any} and compose of: - # state, statevalue, thought, action, observation - plan::Dict{Symbol, Any} = Dict{Symbol, Any}( - - # store 3 to 5 best plan AI frequently used to avoid having to search MCTS all the time - # each plan is in [historyPoint_1, historyPoint_2, ...] format - :existingplan => Vector(), - - :activeplan => Dict{Symbol, Any}(), # current using plan - :currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ... - ) - -end - -function sommelier( - config::Dict = Dict( - :mqttServerInfo=> Dict( - :broker=> nothing, - :port=> nothing, - ), - :receivemsg=> Dict( - :prompt=> nothing, # topic to receive prompt i.e. frontend send msg to this topic - :internal=> nothing, - ), - :thirdPartyService=> Dict( - :text2textinstruct=> nothing, - :text2textchat=> nothing, - ), - ) - ; - name::String= "Assistant", - id::String= string(uuid4()), - tools::Dict= Dict( - :chatbox=> Dict( - :name => "chatbox", - :description => "Useful for when you need to communicate with the user.", - :input => "Input should be a conversation to the user.", - :output => "" , - :func => nothing, - ), - ), - maxHistoryMsg::Integer= 20, - thinkinglimit::Integer= 5, - thinkingcount::Integer= 0, - ) - - #[NEXTVERSION] publish to a.config[:configtopic] to get a config. - #[NEXTVERSION] get a config message in a.mqttMsg_internal - #[NEXTVERSION] set agent according to config - - newAgent = sommelier( - name= name, - id= id, - config= config, - maxHistoryMsg= maxHistoryMsg, - tools= tools, - thinkinglimit= thinkinglimit, - thinkingcount= thinkingcount, - ) - - return newAgent -end - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -end # module type \ No newline at end of file diff --git a/test/runtest.jl b/test/runtest.jl index 3e4fc6a..70d117d 100644 --- a/test/runtest.jl +++ b/test/runtest.jl @@ -59,26 +59,26 @@ tools=Dict( # update input format # response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) ) -response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine", - :select=> nothing, - :reward=> 0, - :isterminal=> false, - ) ) -println("---> YiemAgent: ", response) +# response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine", +# :select=> nothing, +# :reward=> 0, +# :isterminal=> false, +# ) ) +# println("---> YiemAgent: ", response) -response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.", - :select=> nothing, - :reward=> 0, - :isterminal=> false, - ) ) -println("---> YiemAgent: ", response) +# #BUG mcts do not start at current chat history +# response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.", +# :select=> nothing, +# :reward=> 0, +# :isterminal=> false, +# ) ) +# println("---> YiemAgent: ", response) - -"It will be Thai dishes." -"I like medium-bodied with low tannin." +dummyinput = "price < 50, full-bodied red wine with sweetness level 2, low tannin level and medium acidity level, Thai dishes" +response = YiemAgent.winestock(a, dummyinput)