From ae9d80df5e4743c68cfedfe488d078ed3e632d9b Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Sat, 4 May 2024 17:41:31 +0700 Subject: [PATCH] update --- Manifest.toml | 11 +++-- Project.toml | 1 + src/interface.jl | 52 ++++++++++++++------- src/llmfunction.jl | 111 ++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 154 insertions(+), 21 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 3992a95..767e318 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.2" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "60ddc268a63725d93580a5caeda2cac7b1579c68" +project_hash = "c6233f8bf690740dd830d1f0927bd3afed93b8d2" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -50,7 +50,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" +version = "1.1.1+0" [[deps.ConcurrentUtilities]] deps = ["Serialization", "Sockets"] @@ -392,6 +392,11 @@ git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.4.3" +[[deps.PrettyPrinting]] +git-tree-sha1 = "142ee93724a9c5d04d78df7006670a93ed1b244e" +uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" +version = "0.4.2" + [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/Project.toml b/Project.toml index e2afee9..b3951f5 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" MQTTClient = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" +PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" diff --git a/src/interface.jl b/src/interface.jl index 3dc6d86..e03e89f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -2,7 +2,7 @@ module interface export addNewMessage, conversation, decisionMaker, progressValueEstimator, isterminal -using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient +using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient, PrettyPrinting using GeneralUtils using ..type, ..util, ..llmfunction, ..mcts @@ -98,6 +98,16 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 # (trajectories) # """ + responseformat = + """ + You should only respond in JSON format as describe below: + { + "Thought": "your reasoning", + "Action": {"name": "action to take", "input": "Action input"}, + "Observation": "result of the action" + } + """ + _prompt = """ You are a helpful sommelier working for a wine store. @@ -119,12 +129,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 2) chatbox[text], which you can use to interact with the user. 3) recommendation[answer], which returns your wine reccommendation to the user. - You should only respond in JSON format as describe below: - { - "Thought": "your reasoning", - "Action": {"name": "action to take", "input": "Action input"}, - "Observation": "result of the action" - } + $responseformat Here are some examples: { @@ -164,12 +169,18 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 :msgMeta=> msgMeta, :payload=> Dict( :text=> prompt, + :kwargs=> Dict( + :max_tokens=> 512, + :stop=> ["<|eot_id|>"], + ) ) ) - @show outgoingMsg + _response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) - thoughtJsonStr = _response[:response][:text] + _thoughtJsonStr = _response[:response][:text] + thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, "") thoughtDict = copy(JSON3.read(thoughtJsonStr)) + pprint(thoughtDict) return thoughtDict end @@ -196,6 +207,17 @@ julia> # Signature """ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict} + responseformat = + """ + You should only respond in JSON format as describe below: + { + "Thought_1": "reasoning 1", + "Action_1": {"name": "action to take", "input": "Action input"}, + "Observation_1": "result of the action", + "Evaluation_1": {"evaluation": "your evaluation", "score": your evaluation score} + } + """ + _prompt = """ Analyze the trajectories of a solution to a question answering task. The trajectories are @@ -211,13 +233,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where yet. Do not generate additional thoughts or actions. Then ending with the correctness score s where s is an integer from 1 to 10. - You should only respond in JSON format as describe below: - { - "Thought_1": "reasoning 1", - "Action_1": {"name": "action to take", "input": "Action input"}, - "Observation_1": "result of the action", - "Evaluation_1": {"evaluation": "your evaluation", "score": your evaluation score} - } + $responseformat Here are some examples: { @@ -232,6 +248,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where "score": 10} } + Let's begin!: $(JSON3.write(state[:thoughtHistory])) """ @@ -254,7 +271,8 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where ) _response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) - thoughtJsonStr = _response[:response][:text] + _thoughtJsonStr = _response[:response][:text] + thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, responseformat) thoughtDict = copy(JSON3.read(thoughtJsonStr)) latestEvaluationKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Evaluation") diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 4542bac..505a347 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -1,6 +1,6 @@ module llmfunction -export virtualWineCustomerChatbox +export virtualWineCustomerChatbox, jsoncorrection using HTTP, JSON3, URIs, Random using GeneralUtils @@ -362,6 +362,115 @@ end # return result # end + +""" Attemp to correct LLM response's incorrect JSON response. + +# Arguments + - `a::T1` + one of Yiem's agent + - `input::T2` + text to be send to virtual wine customer + +# Return + - `response::String` + response of virtual wine customer +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docstring + - [TESTING] implement the function + +# Signature +""" +function jsoncorrection(a::T1, input::T2, + correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} + + attemptround = 0 + incorrectjson = input + correctjson = nothing + while true + attemptround += 1 + if attemptround <= 5 + try + JSON3.read(incorrectjson) + correctjson = incorrectjson + break + catch + println("Attempting correct JSON string. $attempting") + _prompt = + """ + Your goal is to correct a given incorrect JSON string. + + $correctJsonExample + + Incorrect JSON: + $incorrectjson + Corrention: + """ + + externalService = a.config[:externalservice][:text2textinstruct] + llminfo = externalService[:llminfo] + prompt = + if llminfo[:name] == "llama3instruct" + formatLLMtext_llama3instruct("system", _prompt) + else + error("llm model name is not defied yet $(@__LINE__)") + end + + # send formatted input to user using GeneralUtils.sendReceiveMqttMsg + msgMeta = GeneralUtils.generate_msgMeta( + externalService[:mqtttopic], + senderName= "jsoncorrection", + senderId= a.id, + receiverName= "text2textinstruct", + mqttBroker= a.config[:mqttServerInfo][:broker], + mqttBrokerPort= a.config[:mqttServerInfo][:port], + ) + + outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> prompt, + :kwargs=> Dict( + :max_tokens=> 512, + :stop=> ["<|eot_id|>"], + ) + ) + ) + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + incorrectjson = result[:response][:text] + end + else + error("Can't fix JSON string") + break + end + end + @show correctjson + return correctjson +end + + + + + + + + + + + + + + + + + + + +