diff --git a/src/interface.jl b/src/interface.jl index 0a8d62a..a6f881c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -265,7 +265,8 @@ julia> # Signature """ -function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict} +function evaluator(config::T1, state::T2 + )::Tuple{String, Integer} where {T1<:AbstractDict, T2<:AbstractDict} systemmsg = """ @@ -342,19 +343,19 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T """ pprint(prompt) - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] # apply LLM specific instruct format - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] msgMeta = GeneralUtils.generate_msgMeta( externalService[:mqtttopic], senderName= "evaluator", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], ) outgoingMsg = Dict( @@ -377,7 +378,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T Here is an expected JSON format: {"evaluation": "...", "score": "..."} """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample) evaluationDict = copy(JSON3.read(responseJsonStr)) # check if dict has all required value @@ -671,7 +672,7 @@ julia> # Signature """ -function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} +function reflector(config::T1, state::T2)::String where {T1<:AbstractDict, T2<:AbstractDict} # https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py _prompt = @@ -727,7 +728,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} """ # apply LLM specific instruct format - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] llminfo = externalService[:llminfo] prompt = if llminfo[:name] == "llama3instruct" @@ -739,10 +740,10 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} msgMeta = GeneralUtils.generate_msgMeta( a.config[:externalservice][:text2textinstruct][:mqtttopic], senderName= "reflector", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], ) outgoingMsg = Dict( @@ -765,7 +766,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} Here is an expected JSON format: {"reflection": "..."} """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample) reflectionDict = copy(JSON3.read(responseJsonStr)) # check if dict has all required value @@ -791,45 +792,100 @@ function transition() end -# """ Determine whether the state is a terminal state -# # Arguments -# - `state::T` -# a game state - -# # Return -# - `(isterminalstate, reward)::Tuple{Bool, <:Number}` -# # Example -# ```jldoctest -# julia> -# ``` +""" Get a new state -# # TODO +# Arguments + - `a::T1` + one of YiemAgent's agent + - `state::T2` + current game state + - `thoughtDict::T3` + contain Thought, Action, Observation + - `isterminal::Function` + a function to determine terminal state -# # Signature -# """ -# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict} -# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation") -# latestObservation = state[:thoughtHistory][latestObservationKey] +# Return + - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}` -# if latestObservation !== nothing +# Example +```jldoctest +julia> state = Dict{Symbol, Dict{Symbol, Any}}( + :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."), + :storeinfo => Dict(), + :customerinfo => Dict() + ) +julia> thoughtDict = Dict( + :question=> "I want to buy a bottle of wine.", + :thought_1=> "The customer wants to buy a bottle of wine.", + :action_1=> Dict{Symbol, Any}( + :name=>"Chatbox", + :input=>"What occasion are you buying the wine for?", + ), + :observation_1 => "" + ) +``` -# # terminal condition is when the user select wine by putting <> in latest observation -# if occursin("<<", latestObservation) && occursin(">>", latestObservation) -# isterminalstate = true -# reward = 1 -# else -# isterminalstate = false -# reward = 0 -# end -# else -# isterminalstate = false -# reward = 0 -# end +# TODO + - [] add other actions + - [WORKING] add embedding of newstate and store in newstate[:embedding] + +# Signature +""" +function transition(config::T1, state::T2, decisionMaker::Function, evaluator::Function, + reflector::Function + )::Tuple{String, Dict{Symbol, <:Any}, Integer} where {T1<:AbstractDict, T2<:AbstractDict} + + thoughtDict = decisionMaker(config, state) + + actionname = thoughtDict[:action][:name] + actioninput = thoughtDict[:action][:input] + + # map action and input() to llm function + response, select, reward, isterminal = + if actionname == "chatbox" + # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean + # so that other simulation start from this same node is not contaminated with actioninput + virtualWineUserChatbox(config, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer + elseif actionname == "winestock" + winestock(config, actioninput) + elseif actionname == "recommendbox" + virtualWineUserRecommendbox(config, actioninput) + else + error("undefined LLM function. Requesting $actionname") + end + + newNodeKey, newstate = LLMMCTS.makeNewState(state, thoughtDict, response, select, reward, + isterminal) + if actionname == "chatbox" + push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) ) + push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response)) + end + + stateevaluation, progressvalue = evaluator(config, newstate) + + if newstate[:reward] < 0 + pprint(newstate[:thoughtHistory]) + newstate[:evaluation] = stateevaluation + newstate[:lesson] = reflector(config, newstate) + + # store new lesson for later use + lessonDict = copy(JSON3.read("lesson.json")) + latestLessonKey, latestLessonIndice = + GeneralUtils.findHighestIndexKey(lessonDict, "lesson") + nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1 + newLessonKey = Symbol("lesson_$(nextIndice)") + lessonDict[newLessonKey] = newstate + open("lesson.json", "w") do io + JSON3.pretty(io, lessonDict) + end + print("---> reflector()") + end + + return (newNodeKey, newstate, progressvalue) +end -# return (isterminalstate, reward) -# end """ Chat with llm. @@ -960,58 +1016,6 @@ end -# function conversation(a::T, userinput::Dict) where {T<:agent} - -# # get new user msg from a.receiveUserMsgChannel - - -# # "newtopic" command to delete chat history -# if userinput[:text] == "newtopic" -# clearhistory(a) - -# return "Okay. What shall we talk about?" - -# else -# # add usermsg to a.chathistory -# addNewMessage(a, "user", userinput[:text]) - -# currentstate = -# if isempty(a.plan[:currenttrajectory]) -# # set up initial state -# Dict{Symbol, Any}( -# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning -# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), -# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), -# :userselect=> nothing, -# :reward=> 0, -# :isterminal=> false, -# :evaluation=> nothing, -# :lesson=> nothing, -# :thoughtDict=> nothing, -# :totalTrajectoryReward=> nothing, -# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... -# # :recap=>, -# :question=> userinput[:text], -# ) -# ) -# else -# a.plan[:currenttrajectory] -# end - -# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector, -# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0) - -# # transition -# newstate = transition(a, bestNextState) -# a.plan[:currenttrajectory] = newstate - -# end -# end - - - - - diff --git a/src/llmfunction.jl b/src/llmfunction.jl index c169929..612a3e8 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -3,7 +3,7 @@ module llmfunction export virtualWineUserChatbox, jsoncorrection, winestock, virtualWineUserRecommendbox, userChatbox, userRecommendbox -using HTTP, JSON3, URIs, Random, PrettyPrinting +using HTTP, JSON3, URIs, Random, PrettyPrinting, UUIDs using GeneralUtils using ..type, ..util @@ -168,8 +168,8 @@ julia> # Signature """ -function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory - )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} +function virtualWineUserChatbox(config::T1, input::T2, virtualCustomerChatHistory + )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString} previouswines = """ @@ -275,16 +275,16 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory """ pprint(prompt) - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] # send formatted input to user using GeneralUtils.sendReceiveMqttMsg msgMeta = GeneralUtils.generate_msgMeta( externalService[:mqtttopic], senderName= "virtualWineUserChatbox", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], msgId = "dummyid" #CHANGE remove after testing finished ) @@ -310,7 +310,7 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory "isterminal": "..." } """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample) responseDict = copy(JSON3.read(responseJsonStr)) text::AbstractString = responseDict[:text] @@ -362,8 +362,8 @@ julia> result = winestock(agent, input) # Signature """ -function winestock(a::T1, input::T2 - )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} +function winestock(config::T1, input::T2 + )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString} # SELECT * # FROM food @@ -469,16 +469,16 @@ function winestock(a::T1, input::T2 """ pprint(prompt) - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] # send formatted input to user using GeneralUtils.sendReceiveMqttMsg msgMeta = GeneralUtils.generate_msgMeta( externalService[:mqtttopic], senderName= "virtualWineUserChatbox", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], msgId = "dummyid" #CHANGE remove after testing finished ) @@ -505,7 +505,7 @@ function winestock(a::T1, input::T2 } } """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample) _responseDict = copy(JSON3.read(responseJsonStr)) responseDict = _responseDict[:attributes] @@ -559,8 +559,8 @@ function winestock(a::T1, input::T2 # return result, nothing, 0, false end -function wineattributes_wordToNumber(a::T1, input::T2 - )::Dict where {T1<:agent, T2<:AbstractString} +function wineattributes_wordToNumber(config::T1, input::T2 + )::Dict where {T1<:AbstractDict, T2<:AbstractString} systemmsg = """ @@ -656,16 +656,16 @@ function wineattributes_wordToNumber(a::T1, input::T2 """ pprint(prompt) - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] # send formatted input to user using GeneralUtils.sendReceiveMqttMsg msgMeta = GeneralUtils.generate_msgMeta( externalService[:mqtttopic], senderName= "wineattributes_wordToNumber", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], ) outgoingMsg = Dict( @@ -691,7 +691,7 @@ function wineattributes_wordToNumber(a::T1, input::T2 } } """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample) _responseDict = copy(JSON3.read(responseJsonStr)) responseDict = _responseDict[:attributes]