From e9c91fdb4de0b3cca68368d16b3f22c9924eb238 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Wed, 15 May 2024 13:35:26 +0700 Subject: [PATCH] update --- src/interface.jl | 131 ++++++++++++++++++++++++++++++++++----------- src/llmfunction.jl | 4 +- src/mcts.jl | 100 ++++++++++++++++++++-------------- src/type.jl | 4 +- test/runtest.jl | 15 +++++- test/test_1.jl | 2 + 6 files changed, 179 insertions(+), 77 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 8eba29b..08f44b5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -39,7 +39,8 @@ macro executeStringFunction(functionStr, args...) func_expr = Meta.parse(functionStr) # Create a new function with the parsed expression - function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...), func_expr.args[2:end]...)) + function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...), + func_expr.args[2:end]...)) # Call the newly created function with the provided arguments function_to_call(args...) @@ -744,47 +745,61 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?") # Signature """ function conversation(a::T, userinput::Dict) where {T<:agent} - - # "newtopic" command to delete chat history if userinput[:text] == "newtopic" clearhistory(a) - return "Okay. What shall we talk about?" - else # add usermsg to a.chathistory addNewMessage(a, "user", userinput[:text]) - currentstate = if isempty(a.plan[:currenttrajectory]) - # set up initial state - Dict{Symbol, Any}( - # deepcopy the info to prevent modifying the info unintentionally during MCTS planning - :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), - :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), - :userselect=> nothing, - :reward=> 0, - :isterminal=> false, - :evaluation=> nothing, - :lesson=> nothing, - :thoughtDict=> nothing, - :totalTrajectoryReward=> nothing, - :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... - # :recap=>, - :question=> userinput[:text], - ) - ) + a.plan[:currenttrajectory] = Dict{Symbol, Any}( + # deepcopy the info to prevent modifying the info unintentionally during MCTS planning + :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), + :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), + :userselect=> nothing, + :reward=> 0, + :isterminal=> false, + :evaluation=> nothing, + :lesson=> nothing, + + :totalTrajectoryReward=> nothing, + + # contain question, thought_1, action_1, observation_1, thought_2, ... + :thoughtHistory=> OrderedDict{Symbol, Any}( + #[] :recap=>, + :question=> userinput[:text], + ) + ) else - a.plan[:currenttrajectory] + _, a.plan[:currenttrajectory] = makeNewState(a.plan[:currenttrajectory], + a.plan[:activeplan][:thoughtHistory], userinput[:text], userinput[:select], + userinput[:reward], userinput[:isterminal]) end + end + + while true + bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker, + evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=1, explorationweight=1.0) + a.plan[:activeplan] = bestNextState + + latestActionKey, latestActionIndice = + GeneralUtils.findHighestIndexKey(bestNextState[:thoughtHistory], "action") + actionname = bestNextState[:thoughtHistory][latestActionKey][:name] + actioninput = bestNextState[:thoughtHistory][latestActionKey][:input] - bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector, - totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0) - # transition - newstate = transition(a, bestNextState) - a.plan[:currenttrajectory] = newstate - + if actionname == "chatbox" + # add usermsg to a.chathistory + addNewMessage(a, "assistant", actioninput) + return actioninput + elseif actionname == "recommendbox" + # add usermsg to a.chathistory + addNewMessage(a, "assistant", actioninput) + return actioninput + else + _, a.plan[:currenttrajectory] = transition(a, a.plan[:activeplan]) + end end end @@ -797,6 +812,62 @@ end +# function conversation(a::T, userinput::Dict) where {T<:agent} + +# # get new user msg from a.receiveUserMsgChannel + + +# # "newtopic" command to delete chat history +# if userinput[:text] == "newtopic" +# clearhistory(a) + +# return "Okay. What shall we talk about?" + +# else +# # add usermsg to a.chathistory +# addNewMessage(a, "user", userinput[:text]) + +# currentstate = +# if isempty(a.plan[:currenttrajectory]) +# # set up initial state +# Dict{Symbol, Any}( +# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning +# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), +# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), +# :userselect=> nothing, +# :reward=> 0, +# :isterminal=> false, +# :evaluation=> nothing, +# :lesson=> nothing, +# :thoughtDict=> nothing, +# :totalTrajectoryReward=> nothing, +# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... +# # :recap=>, +# :question=> userinput[:text], +# ) +# ) +# else +# a.plan[:currenttrajectory] +# end + +# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector, +# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0) + +# # transition +# newstate = transition(a, bestNextState) +# a.plan[:currenttrajectory] = newstate + +# end +# end + + + + + + + + + diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 523244a..98df02f 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -22,7 +22,7 @@ julia> # TODO - [] update docstring - - [PENDING] implement the function + - [WORKING] implement the function # Signature """ @@ -293,7 +293,7 @@ function jsoncorrection(a::T1, input::T2, correctjson = incorrectjson return correctjson catch e - @warn "Attempting correct JSON string. Attempt $attempt" + @warn "Attempting to correct JSON string. Attempt $attempt" e = """$e""" if occursin("EOF", e) e = split(e, "EOF")[1] * "EOF" diff --git a/src/mcts.jl b/src/mcts.jl index c9e2df1..5632351 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -6,7 +6,7 @@ module mcts export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition, - userChatbox + userChatbox, makeNewState using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting using GeneralUtils @@ -144,10 +144,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, println("---> expand() sample $nthSample") pprintln(node.state[:thoughtHistory]) pprintln(thoughtDict) - node.state[:thoughtDict] = thoughtDict - newNodeKey, newstate = MCTStransition(a, node.state) + newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) - # add evaluator + stateevaluation, progressvalue = evaluator(a, newstate) if newstate[:reward] < 0 @@ -294,10 +293,9 @@ julia> thoughtDict = Dict( # Signature """ -function MCTStransition(a::T1, state::T2 +function MCTStransition(a::T1, state::T2, thoughtDict::T2 )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict} - thoughtDict = state[:thoughtDict] actionname = thoughtDict[:action][:name] actioninput = thoughtDict[:action][:input] @@ -313,25 +311,7 @@ function MCTStransition(a::T1, state::T2 error("undefined LLM function. Requesting $actionname") end - latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], - "thought") - nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 - latestThoughtKey = Symbol("thought_$nextIndice") - latestActionKey = Symbol("action_$nextIndice") - - # add Thought, action, observation to thoughtHistory - newstate = deepcopy(state) - newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] - newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] - newObservationKey = Symbol("observation_$(nextIndice)") - newstate[:thoughtHistory][newObservationKey] = response - newstate[:reward] = reward - newstate[:select] = select - newstate[:isterminal] = isterminal - - newNodeKey = GeneralUtils.uuid4snakecase() - - return (newNodeKey, newstate) + return makeNewState(state, thoughtDict, response, select, reward, isterminal) end @@ -374,7 +354,7 @@ julia> thoughtDict = Dict( # Signature """ -function transition(a::T1, state::T2 +function transition(a::T1, state::T2, thoughtDict::T2 )::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict} thoughtDict = state[:thoughtDict] @@ -383,36 +363,74 @@ function transition(a::T1, state::T2 # map action and input() to llm function response, select, reward, isterminal = - if actionname == "chatbox" - userChatbox(a, actioninput) # virtual customer - elseif actionname == "winestock" + if actionname == "winestock" winestock(a, actioninput) - elseif actionname == "recommendbox" - userRecommendbox(a, actioninput) else error("undefined LLM function. Requesting $actionname") end - latestThoughtKey, latestThoughtIndice = - GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought") - nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 - latestThoughtKey = Symbol("thought_$nextIndice") - latestActionKey = Symbol("action_$nextIndice") + return makeNewState(state, thoughtDict, response, select, reward, isterminal) +end + + +""" + +# Arguments + +# Return + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docstring + - [TESTING] implement the function + +# Signature +""" +function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, + reward::T3, isterminal::Bool + )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} + + currentstate_latestThoughtKey, currentstate_latestThoughtIndice = + GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") + currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 + currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") + latestActionKey = Symbol("action_$currentstate_nextIndice") + + _, thoughtDict_latestThoughtIndice = + GeneralUtils.findHighestIndexKey(thoughtDict, "thought") + + thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = + if thoughtDict_latestThoughtIndice == -1 + (:thought, :action) + else + ( + Symbol("thought_$thoughtDict_latestThoughtIndice"), + Symbol("action_$thoughtDict_latestThoughtIndice"), + ) + end # add Thought, action, observation to thoughtHistory - newstate = deepcopy(state) - newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] - newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] - newObservationKey = Symbol("observation_$(nextIndice)") + newstate = deepcopy(currentstate) + newstate[:thoughtHistory][currentstate_latestThoughtKey] = + thoughtDict[thoughtDict_latestThoughtKey] + newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] + newObservationKey = Symbol("observation_$(currentstate_nextIndice)") newstate[:thoughtHistory][newObservationKey] = response newstate[:reward] = reward newstate[:select] = select newstate[:isterminal] = isterminal - return newstate + newNodeKey = GeneralUtils.uuid4snakecase() + + return (newNodeKey, newstate) end + """ Determine whether a node is a leaf node of a search tree. # Arguments diff --git a/src/type.jl b/src/type.jl index 3d4f317..c70253a 100644 --- a/src/type.jl +++ b/src/type.jl @@ -111,8 +111,8 @@ julia> agent = YiemAgent.bsommelier( # each plan is in [historyPoint_1, historyPoint_2, ...] format :existingplan => Vector(), - :activeplan => Vector{Dict{Symbol, Any}}(), # current using plan - :currenttrajectory=> Vector{Dict{Symbol, Any}}(), # store + :activeplan => Dict{Symbol, Any}(), # current using plan + :currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ... ) # put incoming message here. waiting for further processing diff --git a/test/runtest.jl b/test/runtest.jl index a07212f..dd617e9 100644 --- a/test/runtest.jl +++ b/test/runtest.jl @@ -59,10 +59,21 @@ tools=Dict( # update input format tools=tools, ) -response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",) ) - +# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) ) +response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine", + :select=> nothing, + :reward=> 0, + :isterminal=> false, + ) ) +println("---> YiemAgent: ", response) +response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening", + :select=> nothing, + :reward=> 0, + :isterminal=> false, + ) ) +println("---> YiemAgent: ", response) diff --git a/test/test_1.jl b/test/test_1.jl index 3e02670..2f05c65 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -1,6 +1,8 @@ using Revise using YiemAgent, GeneralUtils, JSON3, DataStructures +# ---------------------------------------------- 100 --------------------------------------------- # + msgMeta = Dict(:requestResponse => nothing, :msgPurpose => nothing, :receiverId => nothing,