diff --git a/src/interface.jl b/src/interface.jl index f38f852..56baccc 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -823,7 +823,7 @@ function transition(state::T, args::NamedTuple # so that other simulation start from this same node is not contaminated with actioninput listAllTable_json(executeSQL) elseif thoughtDict[:action_name] == "TABLEINFO" - input = thoughtDict[:action_input] # BUG thoughtDict[:action_input] = "\"wine\"" + input = thoughtDict[:action_input] tableinfo(executeSQL, input) elseif thoughtDict[:action_name] == "GETDATA" response = SQLexecution(executeSQL, thoughtDict[:action_input]) @@ -978,11 +978,13 @@ function query(query::T, executeSQL::Function, text2textInstructLLM::Function; earlystop(state) = state[:reward] >= 8 ? true : false - _, resultState = LLMMCTS.runMCTS(initialstate, transition, transitionargs; + _, _, resultState = LLMMCTS.runMCTS(initialstate, transition, transitionargs; horizontalSampleExpansionPhase=2, horizontalSampleSimulationPhase=1, - maxdepth=3, maxiterations=2, explorationweight=1.0, - earlystop=earlystop) + maxSimulationDepth=3, maxiterations=2, + explorationweight=1.0, + earlystop=earlystop, + saveSimulatedNode=true) latestKey, latestInd = GeneralUtils.findHighestIndexKey(resultState[:thoughtHistory], "observation") action_input = Symbol("action_input_$latestInd") # latest sql sql = resultState[:thoughtHistory][action_input] diff --git a/test/runtest.jl b/test/runtest.jl index af20ced..4b1c6f3 100644 --- a/test/runtest.jl +++ b/test/runtest.jl @@ -3,7 +3,7 @@ using LibPQ, Dates, JSON3, PrettyPrinting, UUIDs, DataFrames, DataStructures, Ba using GeneralUtils, SQLLLM -config = copy(JSON3.read("/appfolder/mountvolume/config.json")) +config = copy(JSON3.read("/appfolder/mountvolume/appdata/config.json")) function executeSQL(sql::T) where {T<:AbstractString} host = config[:externalservice][:wineDB][:host] @@ -148,9 +148,9 @@ sessionId = "555" # query = Dict(:text=> "How many wines from France do you have that can be paired with lamb?") # query = "How many wines are from United States?" # query = "retailer: Yiem, wine_type: red, sweetness: 1-2, intensity: 4-5, wine price: 20-40" -query = "wine_type: white, country: United States, sweetness: 1-2, tannin: 3, food to be served with wine: pizza" +# query = "wine_type: white, country: United States, sweetness: 1-2, tannin: 3, food to be served with wine: pizza" # query = "wine_type: white, country: Austria, food to be served with wine: pork" -# query = "wine price: less than 25, wine_type: rose, country: France, sweetness: 2, tannin: 3, food to be served with wine: pizza" +query = "wine price: less than 25, wine_type: rose, country: France, sweetness: 2, tannin: 3, food to be served with wine: pizza" # query = Dict(:text=> "wine_type: white, country: France, sweetness: 1") result = SQLLLM.query(query, executeSQL, text2textInstructLLM; insertSQLVectorDB=insertSQLVectorDB,