update
This commit is contained in:
@@ -224,17 +224,23 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
|
|||||||
"""
|
"""
|
||||||
<|start_header_id|>assistant<|end_header_id|>
|
<|start_header_id|>assistant<|end_header_id|>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try
|
|
||||||
response = text2textInstructLLM(prompt)
|
response = text2textInstructLLM(prompt)
|
||||||
println("\nSQL decisionMaker() rawresponse: ", response)
|
println("\nSQL decisionMaker() rawresponse: ", response)
|
||||||
|
|
||||||
|
if occursin("NULL", response)
|
||||||
|
errornote = "\nSQL decisionMaker() NULL response is not allowed"
|
||||||
|
println("Attempt $attempt $errornote ", @__FILE__, " ", @__LINE__)
|
||||||
|
break
|
||||||
|
end
|
||||||
|
|
||||||
header = ["Understanding", "Reasoning", "Plan", "Action_name", "Action_input", "Observation"]
|
header = ["Understanding", "Reasoning", "Plan", "Action_name", "Action_input", "Observation"]
|
||||||
|
|
||||||
# detect if there are more than 1 key per categories
|
# detect if there are more than 1 key per categories
|
||||||
count = GeneralUtils.countGivenWords(response, header)
|
count = GeneralUtils.countGivenWords(response, header)
|
||||||
if sum(count) > length(header)
|
if sum(count) > length(header)
|
||||||
error("\nSQL decisionMaker() duplicated keywords", @__FILE__, " ", @__LINE__)
|
errornote = "\nSQL decisionMaker() duplicated keywords"
|
||||||
|
println("Attempt $attempt $errornote ", @__FILE__, " ", @__LINE__)
|
||||||
|
break
|
||||||
end
|
end
|
||||||
|
|
||||||
# textToDict() search for action_input
|
# textToDict() search for action_input
|
||||||
@@ -257,18 +263,24 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
|
|||||||
|
|
||||||
toollist = ["TABLEINFO", "GETDATA"]
|
toollist = ["TABLEINFO", "GETDATA"]
|
||||||
if responsedict[:action_name] ∉ toollist
|
if responsedict[:action_name] ∉ toollist
|
||||||
error("SQL decisionMaker() didn't use the given functions ", @__FILE__, " ", @__LINE__)
|
errornote = "\nSQL decisionMaker() didn't use the given functions"
|
||||||
|
println("Attempt $attempt $errornote ", @__FILE__, " ", @__LINE__)
|
||||||
|
break
|
||||||
end
|
end
|
||||||
|
|
||||||
for i in toollist
|
for i in toollist
|
||||||
if occursin(i, responsedict[:action_input])
|
if occursin(i, responsedict[:action_input])
|
||||||
error("Action_name is in action_input which is not allowed.")
|
errornote = "\nSQL decisionMaker() action_name is in action_input which is not allowed."
|
||||||
|
println("Attempt $attempt $errornote ", @__FILE__, " ", @__LINE__)
|
||||||
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for i ∈ [:understanding, :reasoning, :plan, :action_name, :action_input]
|
for i ∈ [:understanding, :reasoning, :plan, :action_name, :action_input]
|
||||||
if length(JSON3.write(responsedict[i])) == 0
|
if length(JSON3.write(responsedict[i])) == 0
|
||||||
error("$i is empty ", @__FILE__, " ", @__LINE__)
|
errornote = "\nSQL decisionMaker() $i is empty"
|
||||||
|
println("Attempt $attempt $errornote ", @__FILE__, " ", @__LINE__)
|
||||||
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -276,22 +288,15 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
|
|||||||
for i ∈ [:understanding, :reasoning, :plan, :action_name, :action_input]
|
for i ∈ [:understanding, :reasoning, :plan, :action_name, :action_input]
|
||||||
matchkeys = GeneralUtils.findMatchingDictKey(responsedict, i)
|
matchkeys = GeneralUtils.findMatchingDictKey(responsedict, i)
|
||||||
if length(matchkeys) > 1
|
if length(matchkeys) > 1
|
||||||
error("DecisionMaker has more than one key per categories")
|
errornote = "\nSQL decisionMaker() $i has more than one key"
|
||||||
|
println("Attempt $attempt $errornote ", @__FILE__, " ", @__LINE__)
|
||||||
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
state[:decisionMaker] = responsedict
|
state[:decisionMaker] = responsedict
|
||||||
|
|
||||||
return responsedict
|
return responsedict
|
||||||
catch e
|
|
||||||
io = IOBuffer()
|
|
||||||
showerror(io, e)
|
|
||||||
errorMsg = String(take!(io))
|
|
||||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
|
||||||
println("")
|
|
||||||
println("\n~~~ SQLLLM decisionMaker() Attempt $attempt. Error occurred: $errorMsg\n$st ", @__FILE__, " ", @__LINE__)
|
|
||||||
println("")
|
|
||||||
end
|
|
||||||
|
|
||||||
end
|
end
|
||||||
error("DecisionMaker failed to generate a thought ", response)
|
error("DecisionMaker failed to generate a thought ", response)
|
||||||
@@ -797,7 +802,7 @@ function transition(state::T, args::NamedTuple
|
|||||||
# so that other simulation start from this same node is not contaminated with actioninput
|
# so that other simulation start from this same node is not contaminated with actioninput
|
||||||
listAllTable_json(executeSQL)
|
listAllTable_json(executeSQL)
|
||||||
elseif thoughtDict[:action_name] == "TABLEINFO"
|
elseif thoughtDict[:action_name] == "TABLEINFO"
|
||||||
input = copy(JSON3.read(thoughtDict[:action_input]))
|
input = copy(JSON3.read(thoughtDict[:action_input])) # BUG thoughtDict[:action_input] = "\"wine\""
|
||||||
tableinfo(executeSQL, input)
|
tableinfo(executeSQL, input)
|
||||||
elseif thoughtDict[:action_name] == "GETDATA"
|
elseif thoughtDict[:action_name] == "GETDATA"
|
||||||
response = SQLexecution(executeSQL, thoughtDict[:action_input])
|
response = SQLexecution(executeSQL, thoughtDict[:action_input])
|
||||||
|
|||||||
@@ -520,65 +520,6 @@ julia> response = SQLLLM.SQLexecution(executeSQL, sql)
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
# function SQLexecution(executeSQL::Function, sql::T
|
|
||||||
# )::NamedTuple{(:result, :success, :errormsg, :reward, :isterminal), Tuple{Union{DataFrame, Nothing}, Bool, Union{String, Nothing}, Integer, Bool}} where {T<:AbstractString}
|
|
||||||
# println("\n~~~ 1-01 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# #XXX dummy SQL. use for testing
|
|
||||||
# # sql = "SELECT w.wine_name FROM wine w JOIN wine_food wf ON w.wine_id = wf.wine_id JOIN food f ON wf.food_id = f.food_id WHERE f.\"food_name\" = 'lamb';"
|
|
||||||
# # sql = " SELECT w.wine_name FROM wine w JOIN food f ON f.food_name = 'lamb' JOIN wine_food wf ON w.wine_id = wf.wine_id AND f.food_id = wf.food_id GROUP BY w.wine_name ORDER BY COUNT(DISTINCT w.wine_id) DESC;"
|
|
||||||
# # sql = " SELECT COUNT(DISTINCT wf.wine_id) FROM wine w JOIN wine_food wf ON w.wine_id = wf.wine_id JOIN food f ON wf.food_id = f.food_id WHERE f.food_name ILIKE '%lamb%'"
|
|
||||||
|
|
||||||
# #XXX use for package testing, remove when done
|
|
||||||
# # ans = "1.schilfwein zweigelt 2.cabernet sauvignon reserve limited edition"
|
|
||||||
# # ans = "There are 1500 wines that can be paired with lamb."
|
|
||||||
# # ans = "1500"
|
|
||||||
# # return (response=ans, errormsg=nothing, reward=1, isterminal=true)
|
|
||||||
|
|
||||||
# # add LIMIT to the SQL to prevent loading large data
|
|
||||||
# sql = strip(sql)
|
|
||||||
# println("\n~~~ SQL 1", @__FILE__, " ", @__LINE__)
|
|
||||||
# println(sql)
|
|
||||||
# println("\n~~~ 1-02 ", @__FILE__, " ", @__LINE__)
|
|
||||||
|
|
||||||
# if sql[end] != ';'
|
|
||||||
# errorMsg = "Error, SQL execution failed because it does not ended with ';'"
|
|
||||||
# return (result=nothing, success=false, errormsg=errorMsg, reward=0, isterminal=false)
|
|
||||||
# end
|
|
||||||
# println("\n~~~ 1-03 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# if !occursin("LIMIT", sql)
|
|
||||||
# # sql = sql[1:end-1] * " LIMIT 100;"
|
|
||||||
# sql = sql[1:end-1] * " ORDER BY RANDOM() LIMIT 2;"
|
|
||||||
# end
|
|
||||||
|
|
||||||
# println("\n~~~ SQL 2", @__FILE__, " ", @__LINE__)
|
|
||||||
# println(sql)
|
|
||||||
# println("\n~~~ 1-1 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# result = executeSQL(sql)
|
|
||||||
# println("\n~~~ 1-2 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# df = DataFrame(result)
|
|
||||||
# println("\n~~~ raw df ", df)
|
|
||||||
# tablesize = size(df)
|
|
||||||
# println("\n~~~ df size ", tablesize)
|
|
||||||
# println("\n~~~ 6 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# row = tablesize[1]
|
|
||||||
# println("\n~~~ 7 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# if row == 0 # if 0 row
|
|
||||||
# errorMsg = "The resulting table has 0 row. Possible causes: 1) SQL is incorrect 2) There is no data that match your search criteria."
|
|
||||||
# return (result=nothing, success=false, errormsg=errorMsg, reward=0, isterminal=false)
|
|
||||||
# end
|
|
||||||
# println("\n~~~ 8 ", @__FILE__, " ", @__LINE__)
|
|
||||||
# df1 =
|
|
||||||
# if row > 2
|
|
||||||
# # ramdom row to pick
|
|
||||||
# df[sample(1:nrow(df), 2, replace=false), :] # random select 2 rows from df
|
|
||||||
# else
|
|
||||||
# df
|
|
||||||
# end
|
|
||||||
|
|
||||||
# println("\n~~~ SQLexecution result ", @__FILE__, " ", @__LINE__)
|
|
||||||
# println(df1)
|
|
||||||
# return (result=df1, success=true, errormsg=nothing, reward=1, isterminal=true)
|
|
||||||
# end
|
|
||||||
function SQLexecution(executeSQL::Function, sql::T
|
function SQLexecution(executeSQL::Function, sql::T
|
||||||
) where {T<:AbstractString}
|
) where {T<:AbstractString}
|
||||||
|
|
||||||
@@ -596,9 +537,12 @@ function SQLexecution(executeSQL::Function, sql::T
|
|||||||
|
|
||||||
# add LIMIT to the SQL to prevent loading large data
|
# add LIMIT to the SQL to prevent loading large data
|
||||||
sql = strip(sql)
|
sql = strip(sql)
|
||||||
|
|
||||||
|
# remove DISTINCT keyword because it is incompatible with RANDOM()
|
||||||
|
sql = replace(sql, "DISTINCT" => "")
|
||||||
|
|
||||||
if sql[end] == ';'
|
if sql[end] == ';'
|
||||||
if !occursin("LIMIT", sql)
|
if !occursin("LIMIT", sql)
|
||||||
# sql = sql[1:end-1] * " LIMIT 100;"
|
|
||||||
sql = sql[1:end-1] * " ORDER BY RANDOM() LIMIT 2;"
|
sql = sql[1:end-1] * " ORDER BY RANDOM() LIMIT 2;"
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
|
|||||||
Reference in New Issue
Block a user