diff --git a/src/interface.jl b/src/interface.jl index d44d09e..407736a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -102,8 +102,8 @@ Dict( # Signature """ -function decisionMaker(state::T1, context, text2textInstructLLM::Function, - ; querySQLVectorDBF::Union{T2, Nothing}=nothing +function decisionMaker(state::T1, context, text2textInstructLLM::Function, llmFormatName::String + ; querySQLVectorDBF::Union{T2, Nothing}=nothing, maxattempt=10 )::Dict{Symbol, Any} where {T1<:AbstractDict, T2<:Function} # lessonDict = @@ -203,13 +203,10 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function, :temperature => 0.1, ) - for attempt in 1:10 - if attempt > 1 - println("\nERROR SQLLLM decisionMaker() attempt $attempt/10 ", @__FILE__, ":", @__LINE__, " $(Dates.now())") - llmkwargs[:temperature] = 0.1 * attempt - end + for attempt in 1:maxattempt + attempt > 1 ? llmkwargs[:temperature] += 0.1 : nothing - QandA = generatequestion(state, context, text2textInstructLLM; similarSQL=similarSQL_) + QandA = generatequestion(state, context, text2textInstructLLM, llmFormatName; similarSQL=similarSQL_) usermsg = """ @@ -230,9 +227,9 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function, ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) response = text2textInstructLLM(prompt; llmkwargs=llmkwargs) - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) # LLM tends to generate observation given that it is in the input response = @@ -261,7 +258,7 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function, if occursin("NULL", response) errornote = "\nYour previous attempt was NULL. This is not allowed" - println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") + println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue end @@ -284,11 +281,11 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function, detected_kw = GeneralUtils.detect_keyword(header, response) if 0 ∈ values(detected_kw) errornote = "\nYour previous attempt did not have all points according to the required response format" - println("\nERROR SQLLLM decisionMaker() $errornote \n$response", @__FILE__, ":", @__LINE__, " $(Dates.now())") + println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue elseif sum(values(detected_kw)) > length(header) errornote = "\nYour previous attempt has duplicated points according to the required response format" - println("\nERROR SQLLLM decisionMaker() $errornote \n$response", @__FILE__, ":", @__LINE__, " $(Dates.now())") + println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue end @@ -311,31 +308,29 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function, toollist = ["TABLEINFO", "RUNSQL"] if responsedict[:action_name] ∉ toollist - errornote = "\nYour previous attempt has action_name that is not in the tool list" - println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") + errornote = "Your previous attempt has action_name that is not in the tool list" + println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue end for i in toollist if occursin(i, responsedict[:action_input]) - errornote = "\nYour previous attempt has action_name in action_input which is not allowed" - println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") + errornote = "Your previous attempt has action_name in action_input which is not allowed" + println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue end end for i ∈ Symbol.(dictkey) if length(JSON3.write(responsedict[i])) == 0 - errornote = "\nYour previous attempt has empty value for $i" - println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") + errornote = "Your previous attempt has empty value for $i" + println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue end end state[:decisionMaker] = responsedict - return responsedict - end error("SQLLLM DecisionMaker() failed to generate a thought \n", response) end @@ -361,7 +356,8 @@ julia> # Signature """ -function evaluator(state::T1, text2textInstructLLM::Function; maxattempt=10 +function evaluator(state::T1, text2textInstructLLM::Function, llmFormatName::String; + maxattempt=10 ) where {T1<:AbstractDict} systemmsg = @@ -446,13 +442,13 @@ function evaluator(state::T1, text2textInstructLLM::Function; maxattempt=10 ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) header = ["Trajectory_evaluation:", "Answer_evaluation:", "Accepted_as_answer:", "Score:", "Suggestion:"] dictkey = ["trajectory_evaluation", "answer_evaluation", "accepted_as_answer", "score", "suggestion"] response = text2textInstructLLM(prompt, modelsize="medium") - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) # sometime LLM output something like **Comprehension**: which is not expected response = replace(response, "**"=>"") @@ -484,7 +480,7 @@ function evaluator(state::T1, text2textInstructLLM::Function; maxattempt=10 accepted_as_answer::AbstractString = responsedict[:accepted_as_answer] - if accepted_as_answer ∉ ["Yes", "No"] # [PENDING] add errornote into the prompt + if accepted_as_answer ∉ ["Yes", "No"] errornote = "Your previous attempt's accepted_as_answer has wrong format" println("\nERROR SQLLLM evaluator() Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())") continue @@ -705,15 +701,17 @@ function transition(state::T, args::NamedTuple decisionMakerF::Function = args[:decisionMaker] evaluatorF::Function = args[:evaluator] - reflector::Function = args[:reflector] + # reflector::Function = args[:reflector] context = args[:context] executeSQL::Function = args[:executeSQL] text2textInstructLLM::Function = args[:text2textInstructLLM] - insertSQLVectorDB::Function = args[:insertSQLVectorDB] + # insertSQLVectorDB::Function = args[:insertSQLVectorDB] querySQLVectorDBF::Function = args[:querySQLVectorDB] + llmFormatName::String = args[:llmFormatName] # getting SQL from vectorDB - thoughtDict = decisionMakerF(state, context, text2textInstructLLM; querySQLVectorDBF) + thoughtDict = decisionMakerF(state, context, text2textInstructLLM, llmFormatName; + querySQLVectorDBF) # map action and input() to llm function response = @@ -727,7 +725,8 @@ function transition(state::T, args::NamedTuple elseif thoughtDict[:action_name] == "RUNSQL" response = SQLexecution(executeSQL, thoughtDict[:action_input]) if response[:success] - extracted = extractContent_dataframe(response[:result], text2textInstructLLM, thoughtDict[:action_input]) + extracted = extractContent_dataframe(response[:result], text2textInstructLLM, + thoughtDict[:action_input], llmFormatName) (rawresponse=response[:result], result=extracted, errormsg=nothing, success=true) else (result=nothing, errormsg=response[:errormsg], success=false) @@ -743,7 +742,7 @@ function transition(state::T, args::NamedTuple reward::Integer = haskey(response, :reward) ? response[:reward] : 0 isterminal::Bool = haskey(response, :isterminal) ? response[:isterminal] : false newNodeKey, newstate = makeNewState(state, thoughtDict, rawresponse, JSON3.write(result), select, reward, isterminal) - progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM) + progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM, llmFormatName) return (newNodeKey=newNodeKey, newstate=newstate, progressvalue=progressvalue) end @@ -835,6 +834,7 @@ julia> println(result) function query(query::T, executeSQL::Function, text2textInstructLLM::Function; insertSQLVectorDB::Union{Function, Nothing}=nothing, similarSQLVectorDB::Union{Function, Nothing}=nothing, + llmFormatName="qwen3" )::NamedTuple{(:text, :rawresponse), Tuple{Any, Any}} where {T<:AbstractString} # use similarSQLVectorDB to find similar SQL for the query @@ -997,6 +997,7 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function; text2textInstructLLM=text2textInstructLLM, querySQLVectorDB=similarSQLVectorDB, insertSQLVectorDB=insertSQLVectorDB, + llmFormatName=llmFormatName ) earlystop(state) = state[:reward] >= 8 ? true : false @@ -1010,14 +1011,14 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function; explorationweight=1.0, earlystop=earlystop, saveSimulatedNode=true, - multithread=false) + multithread=true) # compare all high value state answer then select the best one if length(highValueState) > 0 # open("/appfolder/app/highValueState.json", "w") do io # JSON3.pretty(io, highValueState) # end - selected = compareState(query, highValueState, text2textInstructLLM) + selected = compareState(query, highValueState, text2textInstructLLM, llmFormatName) resultState = highValueState[selected] #BUG compareState() select 0 end latestKey, latestInd = GeneralUtils.findHighestIndexKey(resultState[:thoughtHistory], "observation") @@ -1090,8 +1091,9 @@ function makeNewState(currentstate::T1, thoughtDict::T4, rawresponse, response:: end -function generatequestion(state::T1, context, text2textInstructLLM::Function; - similarSQL::Union{T2, Nothing}=nothing, maxattempt=10 +function generatequestion(state::T1, context, text2textInstructLLM::Function, + llmFormatName::String; + similarSQL::Union{T2, Nothing}=nothing, maxattempt=10, )::String where {T1<:AbstractDict, T2<:AbstractString} similarSQL = @@ -1144,9 +1146,10 @@ function generatequestion(state::T1, context, text2textInstructLLM::Function; Here are some examples: Q: What information in the hints is not necessary based on the query? A: Country is not specified in the query thus it should not be included in an SQL - Q: How can I modify a SQL example to fit my specific query needs? A: ... + Q: Why the query failed? + A: ... Let's begin! """ @@ -1181,10 +1184,10 @@ function generatequestion(state::T1, context, text2textInstructLLM::Function; ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) response = text2textInstructLLM(prompt, modelsize="medium") - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) # check if response is valid q_number = count("Q", response) diff --git a/src/llmfunction.jl b/src/llmfunction.jl index dac143c..c026b67 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -347,8 +347,9 @@ end # Signature """ -function getdata_decisionMaker(state::Dict, context::Dict, text2textInstructLLM::Function -)::NamedTuple{(:thought, :code, :success, :errormsg),Tuple{Union{String,Nothing},Union{String,Nothing},Bool,Union{String,Nothing}}} +function getdata_decisionMaker(state::Dict, context::Dict, text2textInstructLLM::Function, + llmFormatName::String + )::NamedTuple{(:thought, :code, :success, :errormsg),Tuple{Union{String,Nothing},Union{String,Nothing},Bool,Union{String,Nothing}}} Hints = "None" @@ -406,10 +407,10 @@ function getdata_decisionMaker(state::Dict, context::Dict, text2textInstructLLM: ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) try response = text2textInstructLLM(prompt, modelsize="medium") - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) header = ["Comprehension:", "Plan:", "Code:"] dictkey = ["comprehension", "plan", "code"] @@ -518,7 +519,7 @@ function SQLexecution(executeSQL::Function, sql::T tablesize = size(df) row, column = tablesize if row == 0 - error("The resulting table has 0 row. Possible causes: 1) Your search criteria might be too specific. Relaxing some conditions could yield better results. Remember, you can always refine your search later. 2) There could be a typo in your search query. 3) You might be searching in the wrong place.") + error("The resulting table has 0 row. Possible causes: 1) 1) Your search criteria might be overly specific. Consider removing or adjusting highly specific conditions (e.g., exact values, exact phrases, narrow ranges). Start with broader terms and refine your search incrementally. This often resolves empty result. 2) There could be a typo in your search query. 3) You might be searching in the wrong place.") elseif column > 30 error("SQL execution failed. An unexpected error occurred. Please try again.") end @@ -561,8 +562,9 @@ end # Signature """ -function extractContent_dataframe(df::DataFrame, text2textInstructLLM::Function, action::String -)::String +function extractContent_dataframe(df::DataFrame, text2textInstructLLM::Function, action::String, + llmFormatName::String + )::String tablesize = size(df) row = tablesize[1] column = tablesize[2] @@ -628,13 +630,13 @@ function extractContent_dataframe(df::DataFrame, text2textInstructLLM::Function, ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) header = ["About_resulting_table:", "Search_summary:"] dictkey = ["about_resulting_table", "search_summary"] for i in 1:5 response = text2textInstructLLM(prompt, modelsize="medium") - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) kw = [] # use for loop and detect_keyword function to get the exact variation of each keyword in the text then push to kw list @@ -736,7 +738,9 @@ julia> result = SQLLLM.getTableNameFromSQL(sql, text2textInstructLLM) # Signature """ -function getTableNameFromSQL(sql::T, text2textInstructLLM::Function)::Vector{String} where {T<:AbstractString} +function getTableNameFromSQL(sql::T, text2textInstructLLM::Function, + llmFormatName::String + )::Vector{String} where {T<:AbstractString} systemmsg = """ Extract table name out of the user query. @@ -764,14 +768,14 @@ function getTableNameFromSQL(sql::T, text2textInstructLLM::Function)::Vector{Str ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) header = ["Table_name:"] dictkey = ["table_name"] for attempt in 1:5 try response = text2textInstructLLM(prompt, modelsize="medium") - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) responsedict = GeneralUtils.textToDict(response, header; dictKey=dictkey, symbolkey=true) response = copy(JSON3.read(responsedict[:table_name])) @@ -820,44 +824,39 @@ julia> - The LLM evaluates attempts based on accuracy and relevance to the original question """ function compareState(question::String, highValueStateList::Vector{T}, - text2textInstructLLM::Function)::Integer where {T<:AbstractDict} + text2textInstructLLM::Function, llmFormatName::String + )::Integer where {T<:AbstractDict} systemmsg = """ - + Your profile: - You are a helpful assistant - - - The user has made multiple attempts to solve the question, resulting in various answers - + Situation: + - The user has made multiple attempts to solve the question, resulting in various answers + Your mission: - Identify and select the most accurate and relevant response from these multiple results for the user - - + At each round of conversation, you will be given the following: Question: the question the user is trying to answer Attempt: the user's attempted actions and their corresponding results - - + You should then respond to the user with the following: Comparison: a comparison of all results from all attempts Rationale: a brief explanation of why the selected response is the most accurate and relevant Selected_response_number: the number the selected response in the list of results (e.g., 1, 2, 3, ...) - - + You should only respond in format as described below: Comparison: ... Rationale: ... Selected_response_number: ... - - - User's question: "How many German wines do you have?" - Attempt 1: - Action: SELECT COUNT(*) FROM wines WHERE country = 'Germany' - Result: 100 wines - Attempt 2: - Action: SELECT COUNT(*) FROM wines WHERE country = 'Germany' AND type = 'Red' - Result: 50 red wines - Comparison: The second attempt counts only German red wines while the first attempt includes all German wines. - Rationale: The user is asking for the number of German wines without specifying a type, so the most accurate response is the first attempt because it includes all German wines. - Selected_response_number:1 - + Here are some examples: + User's question: "How many German wines do you have?" + Attempt 1: + Action: SELECT COUNT(*) FROM wines WHERE country = 'Germany' + Result: 100 wines + Attempt 2: + Action: SELECT COUNT(*) FROM wines WHERE country = 'Germany' AND type = 'Red' + Result: 50 red wines + Comparison: The second attempt counts only German red wines while the first attempt includes all German wines. + Rationale: The user is asking for the number of German wines without specifying a type, so the most accurate response is the first attempt because it includes all German wines. + Selected_response_number:1 Let's begin! """ @@ -918,7 +917,7 @@ function compareState(question::String, highValueStateList::Vector{T}, ] # put in model format - prompt = GeneralUtils.formatLLMtext(_prompt, "granite3") + prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName) header = ["Comparison:", "Rationale:", "Selected_response_number:"] dictkey = ["comparison", "rationale", "selected_response_number"] @@ -928,7 +927,7 @@ function compareState(question::String, highValueStateList::Vector{T}, # sometime LLM output something like **Comprehension**: which is not expected response = replace(response, "**"=>"") response = replace(response, "***"=>"") - response = GeneralUtils.deFormatLLMtext(response, "granite3") + response = GeneralUtils.deFormatLLMtext(response, llmFormatName) # make sure every header is in the response for i in header