This commit is contained in:
2024-12-27 20:33:24 +07:00
parent 1e36ea96e9
commit 38108d7e8d
2 changed files with 8 additions and 8 deletions

View File

@@ -1,7 +1,7 @@
name = "SQLLLM" name = "SQLLLM"
uuid = "2ebc79c7-cc10-4a3a-9665-d2e1d61e63d3" uuid = "2ebc79c7-cc10-4a3a-9665-d2e1d61e63d3"
authors = ["narawat lamaiin <narawat@outlook.com>"] authors = ["narawat lamaiin <narawat@outlook.com>"]
version = "0.2.0" version = "0.2.1"
[deps] [deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"

View File

@@ -318,7 +318,7 @@ julia>
# Signature # Signature
""" """
function evaluator(state::T1, text2textInstructLLM::Function; function evaluator(state::T1, text2textInstructLLM::Function;
addSQLVectorDB::Union{Function, Nothing}=nothing insertSQLVectorDB::Union{Function, Nothing}=nothing
) where {T1<:AbstractDict} ) where {T1<:AbstractDict}
# systemmsg = # systemmsg =
@@ -784,7 +784,7 @@ function transition(state::T, args::NamedTuple
context = args[:context] context = args[:context]
executeSQL::Function = args[:executeSQL] executeSQL::Function = args[:executeSQL]
text2textInstructLLM::Function = args[:text2textInstructLLM] text2textInstructLLM::Function = args[:text2textInstructLLM]
addSQLVectorDBF::Function = args[:addSQLVectorDB] insertSQLVectorDB::Function = args[:insertSQLVectorDB]
querySQLVectorDBF::Function = args[:querySQLVectorDB] querySQLVectorDBF::Function = args[:querySQLVectorDB]
# getting SQL from vectorDB # getting SQL from vectorDB
@@ -820,7 +820,7 @@ function transition(state::T, args::NamedTuple
isterminal::Bool = haskey(response, :isterminal) ? response[:isterminal] : false isterminal::Bool = haskey(response, :isterminal) ? response[:isterminal] : false
newNodeKey, newstate = makeNewState(state, thoughtDict, rawresponse, JSON3.write(result), select, reward, isterminal) newNodeKey, newstate = makeNewState(state, thoughtDict, rawresponse, JSON3.write(result), select, reward, isterminal)
progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM; progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM;
addSQLVectorDB=addSQLVectorDBF) insertSQLVectorDB=insertSQLVectorDB)
return (newNodeKey=newNodeKey, newstate=newstate, progressvalue=progressvalue) return (newNodeKey=newNodeKey, newstate=newstate, progressvalue=progressvalue)
end end
@@ -903,7 +903,7 @@ julia> println(result)
# Signature # Signature
""" """
function query(query::T, executeSQL::Function, text2textInstructLLM::Function; function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
addSQLVectorDB::Union{Function, Nothing}=nothing, insertSQLVectorDB::Union{Function, Nothing}=nothing,
similarSQLVectorDB::Union{Function, Nothing}=nothing, similarSQLVectorDB::Union{Function, Nothing}=nothing,
) where {T<:AbstractString} ) where {T<:AbstractString}
@@ -947,7 +947,7 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
executeSQL=executeSQL, executeSQL=executeSQL,
text2textInstructLLM=text2textInstructLLM, text2textInstructLLM=text2textInstructLLM,
querySQLVectorDB=similarSQLVectorDB, querySQLVectorDB=similarSQLVectorDB,
addSQLVectorDB=addSQLVectorDB, insertSQLVectorDB=insertSQLVectorDB,
) )
earlystop(state) = state[:reward] >= 8 ? true : false earlystop(state) = state[:reward] >= 8 ? true : false
@@ -961,10 +961,10 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
extracted = resultState[:thoughtHistory][latestKey] extracted = resultState[:thoughtHistory][latestKey]
# add to vectorDB only if the answer is achieved and the state is terminal # add to vectorDB only if the answer is achieved and the state is terminal
if addSQLVectorDB !== nothing && resultState[:isterminal] == true && if insertSQLVectorDB !== nothing && resultState[:isterminal] == true &&
resultState[:rawresponse] !== nothing resultState[:rawresponse] !== nothing
addSQLVectorDB(resultState[:thoughtHistory][:question], sql) insertSQLVectorDB(resultState[:thoughtHistory][:question], sql)
end end
return (text=extracted, rawresponse=resultState[:rawresponse]) return (text=extracted, rawresponse=resultState[:rawresponse])