This commit is contained in:
narawat lamaiin
2024-10-14 09:12:06 +07:00
parent 7b2d85da48
commit 4ef968b86e
8 changed files with 473 additions and 432 deletions

View File

@@ -250,8 +250,7 @@ function decisionMaker(state::T1, context, text2textInstructLLM::Function,
error("DecisionMaker has more than one key per categories")
end
end
println("--> SQLLLM decisionMaker() ", @__FILE__, " ", @__LINE__)
pprintln(Dict(responsedict))
return responsedict
catch e
io = IOBuffer()
@@ -499,7 +498,7 @@ julia>
function evaluator(state::T1, text2textInstructLLM::Function;
addSQLVectorDB::Union{Function, Nothing}=nothing
) where {T1<:AbstractDict}
println("Evaluating state", @__FILE__, " ", @__LINE__)
# systemmsg =
# """
# You are a helpful assistant that analyzes agent's trajectories to find solutions and observations (i.e., the results of actions) to answer the user's questions.
@@ -737,14 +736,16 @@ function evaluator(state::T1, text2textInstructLLM::Function;
# mark as terminal state when the answer is achieved
if accepted_as_answer == "Yes"
state[:isterminal] = true
state[:reward] = 1
# use the score as the reward because different answers hold different value for the user.
state[:reward] = responsedict[:score]
#add to vectorDB
if addSQLVectorDB !== nothing
addSQLVectorDB(state)
end
end
println("--> 5 Evaluator ", @__FILE__, " ", @__LINE__)
println("~~~ 5 Evaluator() ", @__FILE__, " ", @__LINE__)
pprintln(Dict(responsedict))
return responsedict[:score]
@@ -953,7 +954,7 @@ julia> state = Dict(
# TODO
- [] add embedding of newstate and store in newstate[:embedding]
- [WORKING] should getdata() return isterminal?
# Signature
"""
function transition(state::T, args::NamedTuple
@@ -992,17 +993,13 @@ function transition(state::T, args::NamedTuple
else
error("undefined LLM function. Requesting $actionname")
end
# this section allow LLM functions above to have different return values.
result = haskey(response, :result) ? response[:result] : nothing
success::Bool = haskey(response, :success) ? response[:success] : false
result = success ? response[:result] : response[:errormsg]
select = haskey(response, :select) ? response[:select] : nothing
reward::Integer = haskey(response, :reward) ? response[:reward] : 0
isterminal::Bool = haskey(response, :isterminal) ? response[:isterminal] : false
errormsg::Union{AbstractString, Nothing} = haskey(response, :errormsg) ? response[:errormsg] : nothing
success::Bool = haskey(response, :success) ? response[:success] : false
newNodeKey, newstate = makeNewState(state, thoughtDict, JSON3.write(result), select, reward, isterminal)
println("SQLLLM transition() 1 ", @__FILE__, " ", @__LINE__)
progressvalue::Integer = evaluatorF(newstate, text2textInstructLLM;
addSQLVectorDB=addSQLVectorDBF)
@@ -1090,14 +1087,14 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
addSQLVectorDB::Union{Function, Nothing}=nothing,
querySQLVectorDB::Union{Function, Nothing}=nothing
)::String where {T<:AbstractString}
#[WORKING] add extra context for Evaluator so that it knows the observation is from seaching a database
# add extra context for the Evaluator so that it knows the observation comes from searching a database
query = "Search the database for {$query}"
initialstate = Dict{Symbol, Any}(
:reward=> 0,
:isterminal=> false,
:evaluation=> "None",
:suggestion=> "None",
:evaluationscore=> 0,
:suggestion=> "None",
:accepted_as_answer=> "No",
:lesson=> nothing,
@@ -1121,10 +1118,13 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function;
addSQLVectorDB=addSQLVectorDB,
)
_, result = LLMMCTS.runMCTS(initialstate, transition, transitionargs;
totalsample=1, maxdepth=3, maxiterations=1, explorationweight=1.0)
latestKey, _ = GeneralUtils.findHighestIndexKey(result[:thoughtHistory], "observation")
resulttext = result[:thoughtHistory][latestKey]
earlystop(state) = state[:reward] >= 8 ? true : false
_, resultState = LLMMCTS.runMCTS(initialstate, transition, transitionargs;
totalsample=1, maxdepth=3, maxiterations=3, explorationweight=1.0,
earlystop=earlystop)
latestKey, _ = GeneralUtils.findHighestIndexKey(resultState[:thoughtHistory], "observation")
resulttext = resultState[:thoughtHistory][latestKey]
return resulttext
end