This commit is contained in:
narawat lamaiin
2025-05-01 07:59:18 +07:00
parent bf223b64b2
commit 5112701dc2
2 changed files with 78 additions and 76 deletions

View File

@@ -102,8 +102,8 @@ Dict(
# Signature
"""
function decisionMaker(state::T1, context, text2textInstructLLM::Function,
; querySQLVectorDBF::Union{T2, Nothing}=nothing
function decisionMaker(state::T1, context, text2textInstructLLM::Function, llmFormatName::String
; querySQLVectorDBF::Union{T2, Nothing}=nothing, maxattempt=10
)::Dict{Symbol, Any} where {T1<:AbstractDict, T2<:Function}
# lessonDict =
@@ -203,13 +203,10 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
:temperature => 0.1,
)
for attempt in 1:10
if attempt > 1
println("\nERROR SQLLLM decisionMaker() attempt $attempt/10 ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
llmkwargs[:temperature] = 0.1 * attempt
end
for attempt in 1:maxattempt
attempt > 1 ? llmkwargs[:temperature] += 0.1 : nothing
QandA = generatequestion(state, context, text2textInstructLLM; similarSQL=similarSQL_)
QandA = generatequestion(state, context, text2textInstructLLM, llmFormatName; similarSQL=similarSQL_)
usermsg =
"""
@@ -230,9 +227,9 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
response = text2textInstructLLM(prompt; llmkwargs=llmkwargs)
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
# LLM tends to generate observation given that it is in the input
response =
@@ -261,7 +258,7 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
if occursin("NULL", response)
errornote = "\nYour previous attempt was NULL. This is not allowed"
println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
end
@@ -284,11 +281,11 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
detected_kw = GeneralUtils.detect_keyword(header, response)
if 0 ∈ values(detected_kw)
errornote = "\nYour previous attempt did not have all points according to the required response format"
println("\nERROR SQLLLM decisionMaker() $errornote \n$response", @__FILE__, ":", @__LINE__, " $(Dates.now())")
println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
elseif sum(values(detected_kw)) > length(header)
errornote = "\nYour previous attempt has duplicated points according to the required response format"
println("\nERROR SQLLLM decisionMaker() $errornote \n$response", @__FILE__, ":", @__LINE__, " $(Dates.now())")
println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
end
@@ -311,31 +308,29 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
toollist = ["TABLEINFO", "RUNSQL"]
if responsedict[:action_name] ∉ toollist
errornote = "\nYour previous attempt has action_name that is not in the tool list"
println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
errornote = "Your previous attempt has action_name that is not in the tool list"
println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
end
for i in toollist
if occursin(i, responsedict[:action_input])
errornote = "\nYour previous attempt has action_name in action_input which is not allowed"
println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
errornote = "Your previous attempt has action_name in action_input which is not allowed"
println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
end
end
for i in Symbol.(dictkey)
if length(JSON3.write(responsedict[i])) == 0
errornote = "\nYour previous attempt has empty value for $i"
println("\nERROR SQLLLM decisionMaker() $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
errornote = "Your previous attempt has empty value for $i"
println("\nERROR SQLLLM decisionMaker(). Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
end
end
state[:decisionMaker] = responsedict
return responsedict
end
error("SQLLLM DecisionMaker() failed to generate a thought \n", response)
end
@@ -361,7 +356,8 @@ julia>
# Signature
"""
function evaluator(state::T1, text2textInstructLLM::Function; maxattempt=10
function evaluator(state::T1, text2textInstructLLM::Function, llmFormatName::String;
maxattempt=10
) where {T1<:AbstractDict}
systemmsg =
@@ -446,13 +442,13 @@ function evaluator(state::T1, text2textInstructLLM::Function; maxattempt=10
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
header = ["Trajectory_evaluation:", "Answer_evaluation:", "Accepted_as_answer:", "Score:", "Suggestion:"]
dictkey = ["trajectory_evaluation", "answer_evaluation", "accepted_as_answer", "score", "suggestion"]
response = text2textInstructLLM(prompt, modelsize="medium")
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
# sometime LLM output something like **Comprehension**: which is not expected
response = replace(response, "**"=>"")
@@ -484,7 +480,7 @@ function evaluator(state::T1, text2textInstructLLM::Function; maxattempt=10
accepted_as_answer::AbstractString = responsedict[:accepted_as_answer]
if accepted_as_answer ∉ ["Yes", "No"] # [PENDING] add errornote into the prompt
if accepted_as_answer ∉ ["Yes", "No"]
errornote = "Your previous attempt's accepted_as_answer has wrong format"
println("\nERROR SQLLLM evaluator() Attempt $attempt/$maxattempt. $errornote ", @__FILE__, ":", @__LINE__, " $(Dates.now())")
continue
@@ -705,15 +701,17 @@ function transition(state::T, args::NamedTuple
decisionMakerF::Function = args[:decisionMaker]
evaluatorF::Function = args[:evaluator]
reflector::Function = args[:reflector]
# reflector::Function = args[:reflector]
context = args[:context]
executeSQL::Function = args[:executeSQL]
text2textInstructLLM::Function = args[:text2textInstructLLM]
insertSQLVectorDB::Function = args[:insertSQLVectorDB]
# insertSQLVectorDB::Function = args[:insertSQLVectorDB]
querySQLVectorDBF::Function = args[:querySQLVectorDB]
llmFormatName::String = args[:llmFormatName]
# getting SQL from vectorDB
thoughtDict = decisionMakerF(state, context, text2textInstructLLM; querySQLVectorDBF)
thoughtDict = decisionMakerF(state, context, text2textInstructLLM, llmFormatName;
querySQLVectorDBF)
# map action and input() to llm function
response =
@@ -727,7 +725,8 @@ function transition(state::T, args::NamedTuple
elseif thoughtDict[:action_name] == "RUNSQL"
response = SQLexecution(executeSQL, thoughtDict[:action_input])
if response[:success]
extracted = extractContent_dataframe(response[:result], text2textInstructLLM, thoughtDict[:action_input])
extracted = extractContent_dataframe(response[:result], text2textInstructLLM,
thoughtDict[:action_input], llmFormatName)
(rawresponse=response[:result], result=extracted, errormsg=nothing, success=true)
else
(result=nothing, errormsg=response[:errormsg], success=false)
@@ -743,7 +742,7 @@ function transition(state::T, args::NamedTuple
reward::Integer = haskey(response, :reward) ? response[:reward] : 0
isterminal::Bool = haskey(response, :isterminal) ? response[:isterminal] : false
newNodeKey, newstate = makeNewState(state, thoughtDict, rawresponse, JSON3.write(result), select, reward, isterminal)
progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM)
progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM, llmFormatName)
return (newNodeKey=newNodeKey, newstate=newstate, progressvalue=progressvalue)
end
@@ -835,6 +834,7 @@ julia> println(result)
function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
insertSQLVectorDB::Union{Function, Nothing}=nothing,
similarSQLVectorDB::Union{Function, Nothing}=nothing,
llmFormatName="qwen3"
)::NamedTuple{(:text, :rawresponse), Tuple{Any, Any}} where {T<:AbstractString}
# use similarSQLVectorDB to find similar SQL for the query
@@ -997,6 +997,7 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
text2textInstructLLM=text2textInstructLLM,
querySQLVectorDB=similarSQLVectorDB,
insertSQLVectorDB=insertSQLVectorDB,
llmFormatName=llmFormatName
)
earlystop(state) = state[:reward] >= 8 ? true : false
@@ -1010,14 +1011,14 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
explorationweight=1.0,
earlystop=earlystop,
saveSimulatedNode=true,
multithread=false)
multithread=true)
# compare all high value state answer then select the best one
if length(highValueState) > 0
# open("/appfolder/app/highValueState.json", "w") do io
# JSON3.pretty(io, highValueState)
# end
selected = compareState(query, highValueState, text2textInstructLLM)
selected = compareState(query, highValueState, text2textInstructLLM, llmFormatName)
resultState = highValueState[selected] #BUG compareState() select 0
end
latestKey, latestInd = GeneralUtils.findHighestIndexKey(resultState[:thoughtHistory], "observation")
@@ -1090,8 +1091,9 @@ function makeNewState(currentstate::T1, thoughtDict::T4, rawresponse, response::
end
function generatequestion(state::T1, context, text2textInstructLLM::Function;
similarSQL::Union{T2, Nothing}=nothing, maxattempt=10
function generatequestion(state::T1, context, text2textInstructLLM::Function,
llmFormatName::String;
similarSQL::Union{T2, Nothing}=nothing, maxattempt=10,
)::String where {T1<:AbstractDict, T2<:AbstractString}
similarSQL =
@@ -1144,9 +1146,10 @@ function generatequestion(state::T1, context, text2textInstructLLM::Function;
Here are some examples:
Q: What information in the hints is not necessary based on the query?
A: Country is not specified in the query thus it should not be included in an SQL
Q: How can I modify a SQL example to fit my specific query needs?
A: ...
Q: Why the query failed?
A: ...
Let's begin!
"""
@@ -1181,10 +1184,10 @@ function generatequestion(state::T1, context, text2textInstructLLM::Function;
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
response = text2textInstructLLM(prompt, modelsize="medium")
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
# check if response is valid
q_number = count("Q", response)

View File

@@ -347,8 +347,9 @@ end
# Signature
"""
function getdata_decisionMaker(state::Dict, context::Dict, text2textInstructLLM::Function
)::NamedTuple{(:thought, :code, :success, :errormsg),Tuple{Union{String,Nothing},Union{String,Nothing},Bool,Union{String,Nothing}}}
function getdata_decisionMaker(state::Dict, context::Dict, text2textInstructLLM::Function,
llmFormatName::String
)::NamedTuple{(:thought, :code, :success, :errormsg),Tuple{Union{String,Nothing},Union{String,Nothing},Bool,Union{String,Nothing}}}
Hints = "None"
@@ -406,10 +407,10 @@ function getdata_decisionMaker(state::Dict, context::Dict, text2textInstructLLM:
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
try
response = text2textInstructLLM(prompt, modelsize="medium")
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
header = ["Comprehension:", "Plan:", "Code:"]
dictkey = ["comprehension", "plan", "code"]
@@ -518,7 +519,7 @@ function SQLexecution(executeSQL::Function, sql::T
tablesize = size(df)
row, column = tablesize
if row == 0
error("The resulting table has 0 row. Possible causes: 1) Your search criteria might be too specific. Relaxing some conditions could yield better results. Remember, you can always refine your search later. 2) There could be a typo in your search query. 3) You might be searching in the wrong place.")
error("The resulting table has 0 row. Possible causes: 1) Your search criteria might be overly specific. Consider removing or adjusting highly specific conditions (e.g., exact values, exact phrases, narrow ranges). Start with broader terms and refine your search incrementally. This often resolves empty results. 2) There could be a typo in your search query. 3) You might be searching in the wrong place.")
elseif column > 30
error("SQL execution failed. An unexpected error occurred. Please try again.")
end
@@ -561,8 +562,9 @@ end
# Signature
"""
function extractContent_dataframe(df::DataFrame, text2textInstructLLM::Function, action::String
)::String
function extractContent_dataframe(df::DataFrame, text2textInstructLLM::Function, action::String,
llmFormatName::String
)::String
tablesize = size(df)
row = tablesize[1]
column = tablesize[2]
@@ -628,13 +630,13 @@ function extractContent_dataframe(df::DataFrame, text2textInstructLLM::Function,
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
header = ["About_resulting_table:", "Search_summary:"]
dictkey = ["about_resulting_table", "search_summary"]
for i in 1:5
response = text2textInstructLLM(prompt, modelsize="medium")
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
kw = []
# use for loop and detect_keyword function to get the exact variation of each keyword in the text then push to kw list
@@ -736,7 +738,9 @@ julia> result = SQLLLM.getTableNameFromSQL(sql, text2textInstructLLM)
# Signature
"""
function getTableNameFromSQL(sql::T, text2textInstructLLM::Function)::Vector{String} where {T<:AbstractString}
function getTableNameFromSQL(sql::T, text2textInstructLLM::Function,
llmFormatName::String
)::Vector{String} where {T<:AbstractString}
systemmsg = """
Extract table name out of the user query.
@@ -764,14 +768,14 @@ function getTableNameFromSQL(sql::T, text2textInstructLLM::Function)::Vector{Str
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
header = ["Table_name:"]
dictkey = ["table_name"]
for attempt in 1:5
try
response = text2textInstructLLM(prompt, modelsize="medium")
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
responsedict = GeneralUtils.textToDict(response, header;
dictKey=dictkey, symbolkey=true)
response = copy(JSON3.read(responsedict[:table_name]))
@@ -820,33 +824,29 @@ julia>
- The LLM evaluates attempts based on accuracy and relevance to the original question
"""
function compareState(question::String, highValueStateList::Vector{T},
text2textInstructLLM::Function)::Integer where {T<:AbstractDict}
text2textInstructLLM::Function, llmFormatName::String
)::Integer where {T<:AbstractDict}
systemmsg =
"""
<Your profile>
Your profile:
- You are a helpful assistant
</Your profile>
<Situation>
The user has made multiple attempts to solve the question, resulting in various answers
<Your mission>
Situation:
- The user has made multiple attempts to solve the question, resulting in various answers
Your mission:
- Identify and select the most accurate and relevant response from these multiple results for the user
</Your mission>
<At each round of conversation, you will be given the following>
At each round of conversation, you will be given the following:
Question: the question the user is trying to answer
Attempt: the user's attempted actions and their corresponding results
</At each round of conversation, you will be given the following>
<You should then respond to the user with the following>
You should then respond to the user with the following:
Comparison: a comparison of all results from all attempts
Rationale: a brief explanation of why the selected response is the most accurate and relevant
Selected_response_number: the number the selected response in the list of results (e.g., 1, 2, 3, ...)
</You should then respond to the user with the following>
<You should only respond in format as described below>
You should only respond in format as described below:
Comparison: ...
Rationale: ...
Selected_response_number: ...
</You should only respond in format as described below>
<Here are some examples>
Here are some examples:
User's question: "How many German wines do you have?"
Attempt 1:
Action: SELECT COUNT(*) FROM wines WHERE country = 'Germany'
@@ -857,7 +857,6 @@ function compareState(question::String, highValueStateList::Vector{T},
Comparison: The second attempt counts only German red wines while the first attempt includes all German wines.
Rationale: The user is asking for the number of German wines without specifying a type, so the most accurate response is the first attempt because it includes all German wines.
Selected_response_number:1
</Here are some examples>
Let's begin!
"""
@@ -918,7 +917,7 @@ function compareState(question::String, highValueStateList::Vector{T},
]
# put in model format
prompt = GeneralUtils.formatLLMtext(_prompt, "granite3")
prompt = GeneralUtils.formatLLMtext(_prompt, llmFormatName)
header = ["Comparison:", "Rationale:", "Selected_response_number:"]
dictkey = ["comparison", "rationale", "selected_response_number"]
@@ -928,7 +927,7 @@ function compareState(question::String, highValueStateList::Vector{T},
# sometime LLM output something like **Comprehension**: which is not expected
response = replace(response, "**"=>"")
response = replace(response, "***"=>"")
response = GeneralUtils.deFormatLLMtext(response, "granite3")
response = GeneralUtils.deFormatLLMtext(response, llmFormatName)
# make sure every header is in the response
for i in header