update
This commit is contained in:
198
src/interface.jl
198
src/interface.jl
@@ -265,7 +265,8 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
||||
function evaluator(config::T1, state::T2
|
||||
)::Tuple{String, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||
|
||||
systemmsg =
|
||||
"""
|
||||
@@ -342,19 +343,19 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
||||
"""
|
||||
|
||||
pprint(prompt)
|
||||
externalService = a.config[:externalservice][:text2textinstruct]
|
||||
externalService = config[:externalservice][:text2textinstruct]
|
||||
|
||||
|
||||
# apply LLM specific instruct format
|
||||
externalService = a.config[:externalservice][:text2textinstruct]
|
||||
externalService = config[:externalservice][:text2textinstruct]
|
||||
|
||||
msgMeta = GeneralUtils.generate_msgMeta(
|
||||
externalService[:mqtttopic],
|
||||
senderName= "evaluator",
|
||||
senderId= a.id,
|
||||
senderId= string(uuid4()),
|
||||
receiverName= "text2textinstruct",
|
||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||
mqttBroker= config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||
)
|
||||
|
||||
outgoingMsg = Dict(
|
||||
@@ -377,7 +378,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
||||
Here is an expected JSON format:
|
||||
{"evaluation": "...", "score": "..."}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||
evaluationDict = copy(JSON3.read(responseJsonStr))
|
||||
|
||||
# check if dict has all required value
|
||||
@@ -671,7 +672,7 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
function reflector(config::T1, state::T2)::String where {T1<:AbstractDict, T2<:AbstractDict}
|
||||
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
|
||||
|
||||
_prompt =
|
||||
@@ -727,7 +728,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
"""
|
||||
|
||||
# apply LLM specific instruct format
|
||||
externalService = a.config[:externalservice][:text2textinstruct]
|
||||
externalService = config[:externalservice][:text2textinstruct]
|
||||
llminfo = externalService[:llminfo]
|
||||
prompt =
|
||||
if llminfo[:name] == "llama3instruct"
|
||||
@@ -739,10 +740,10 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
msgMeta = GeneralUtils.generate_msgMeta(
|
||||
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
||||
senderName= "reflector",
|
||||
senderId= a.id,
|
||||
senderId= string(uuid4()),
|
||||
receiverName= "text2textinstruct",
|
||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||
mqttBroker= config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||
)
|
||||
|
||||
outgoingMsg = Dict(
|
||||
@@ -765,7 +766,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
Here is an expected JSON format:
|
||||
{"reflection": "..."}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||
reflectionDict = copy(JSON3.read(responseJsonStr))
|
||||
|
||||
# check if dict has all required value
|
||||
@@ -791,45 +792,100 @@ function transition()
|
||||
end
|
||||
|
||||
|
||||
# """ Determine whether the state is a terminal state
|
||||
|
||||
# # Arguments
|
||||
# - `state::T`
|
||||
# a game state
|
||||
|
||||
# # Return
|
||||
# - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
|
||||
|
||||
# # Example
|
||||
# ```jldoctest
|
||||
# julia>
|
||||
# ```
|
||||
""" Get a new state
|
||||
|
||||
# # TODO
|
||||
# Arguments
|
||||
- `config::T1`
|
||||
configuration dictionary, passed on to `decisionMaker`, `evaluator`, `reflector` and the action functions
|
||||
- `state::T2`
|
||||
current game state
|
||||
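- `decisionMaker::Function`
|
||||
called as `decisionMaker(config, state)`; returns the thoughtDict (contains Thought, Action, Observation)
|
||||
- `evaluator::Function`
|
||||
called as `evaluator(config, newstate)`; returns `(stateevaluation, progressvalue)`
|
||||
- `reflector::Function`
|
||||
called as `reflector(config, newstate)` when the new state's reward is negative; its result is stored in `newstate[:lesson]`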
|
||||
|
||||
# # Signature
|
||||
# """
|
||||
# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
|
||||
# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
|
||||
# latestObservation = state[:thoughtHistory][latestObservationKey]
|
||||
# Return
|
||||
- `(newNodeKey, newstate, progressvalue)::Tuple{String, Dict{Symbol, <:Any}, Integer}`
|
||||
|
||||
# if latestObservation !== nothing
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
|
||||
:storeinfo => Dict(),
|
||||
:customerinfo => Dict()
|
||||
)
|
||||
julia> thoughtDict = Dict(
|
||||
:question=> "I want to buy a bottle of wine.",
|
||||
:thought_1=> "The customer wants to buy a bottle of wine.",
|
||||
:action_1=> Dict{Symbol, Any}(
|
||||
:name=>"Chatbox",
|
||||
:input=>"What occasion are you buying the wine for?",
|
||||
),
|
||||
:observation_1 => ""
|
||||
)
|
||||
```
|
||||
|
||||
# # terminal condition is when the user select wine by putting <<winename>> in latest observation
|
||||
# if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
||||
# isterminalstate = true
|
||||
# reward = 1
|
||||
# else
|
||||
# isterminalstate = false
|
||||
# reward = 0
|
||||
# end
|
||||
# else
|
||||
# isterminalstate = false
|
||||
# reward = 0
|
||||
# end
|
||||
# TODO
|
||||
- [] add other actions
|
||||
- [WORKING] add embedding of newstate and store in newstate[:embedding]
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function transition(config::T1, state::T2, decisionMaker::Function, evaluator::Function,
    reflector::Function
    )::Tuple{String, Dict{Symbol, <:Any}, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
    # Overview:
    #   1. ask `decisionMaker` for the next thought/action given `state`
    #   2. run the requested action ("chatbox", "winestock" or "recommendbox")
    #   3. build the new state with `LLMMCTS.makeNewState`
    #   4. score the new state with `evaluator`
    #   5. when the new state's reward is negative, ask `reflector` for a lesson and
    #      append the state to "lesson.json"
    # Returns `(newNodeKey, newstate, progressvalue)`.
    # Throws (via `error`) when the requested action name is none of the three above.

    thoughtDict = decisionMaker(config, state)
    action = thoughtDict[:action]
    actionname = action[:name]
    actioninput = action[:input]
    ischatbox = actionname == "chatbox"

    # map the action name and its input to the matching LLM function
    outcome =
        if ischatbox
            # pass a deepcopy of the chat history so that other simulations starting
            # from this same node are not contaminated with actioninput
            virtualWineUserChatbox(config, actioninput,
                deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
        elseif actionname == "winestock"
            winestock(config, actioninput)
        elseif actionname == "recommendbox"
            virtualWineUserRecommendbox(config, actioninput)
        else
            error("undefined LLM function. Requesting $actionname")
        end
    response, userselect, actionreward, terminalflag = outcome

    newNodeKey, newstate = LLMMCTS.makeNewState(
        state, thoughtDict, response, userselect, actionreward, terminalflag)

    # only a chatbox action extends the virtual customer conversation:
    # first the assistant's utterance, then the customer's reply
    if ischatbox
        for (speaker, text) in (("assistant", actioninput), ("user", response))
            push!(newstate[:virtualCustomerChatHistory], Dict(:name=>speaker, :text=>text))
        end
    end

    stateevaluation, progressvalue = evaluator(config, newstate)

    # nothing to reflect on unless the new state carries a negative reward
    newstate[:reward] < 0 || return (newNodeKey, newstate, progressvalue)

    pprint(newstate[:thoughtHistory])
    newstate[:evaluation] = stateevaluation
    newstate[:lesson] = reflector(config, newstate)

    # store new lesson for later use
    # NOTE(review): this passes the file name itself to JSON3.read — presumably the
    # installed JSON3 version treats a short string naming an existing file as a
    # path; confirm, and confirm "lesson.json" exists in the working directory.
    lessons = copy(JSON3.read("lesson.json"))
    highestKey, highestIndex = GeneralUtils.findHighestIndexKey(lessons, "lesson")
    nextIndex =
        if highestKey == :NA  # presumably the "no lesson key found" marker — verify
            1
        else
            highestIndex + 1
        end
    lessons[Symbol("lesson_$(nextIndex)")] = newstate
    open(io -> JSON3.pretty(io, lessons), "lesson.json", "w")
    print("---> reflector()")

    return (newNodeKey, newstate, progressvalue)
end
|
||||
|
||||
# return (isterminalstate, reward)
|
||||
# end
|
||||
|
||||
|
||||
""" Chat with llm.
|
||||
@@ -960,58 +1016,6 @@ end
|
||||
|
||||
|
||||
|
||||
# function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
|
||||
# # get new user msg from a.receiveUserMsgChannel
|
||||
|
||||
|
||||
# # "newtopic" command to delete chat history
|
||||
# if userinput[:text] == "newtopic"
|
||||
# clearhistory(a)
|
||||
|
||||
# return "Okay. What shall we talk about?"
|
||||
|
||||
# else
|
||||
# # add usermsg to a.chathistory
|
||||
# addNewMessage(a, "user", userinput[:text])
|
||||
|
||||
# currentstate =
|
||||
# if isempty(a.plan[:currenttrajectory])
|
||||
# # set up initial state
|
||||
# Dict{Symbol, Any}(
|
||||
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||
# :userselect=> nothing,
|
||||
# :reward=> 0,
|
||||
# :isterminal=> false,
|
||||
# :evaluation=> nothing,
|
||||
# :lesson=> nothing,
|
||||
# :thoughtDict=> nothing,
|
||||
# :totalTrajectoryReward=> nothing,
|
||||
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||
# # :recap=>,
|
||||
# :question=> userinput[:text],
|
||||
# )
|
||||
# )
|
||||
# else
|
||||
# a.plan[:currenttrajectory]
|
||||
# end
|
||||
|
||||
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
||||
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||
|
||||
# # transition
|
||||
# newstate = transition(a, bestNextState)
|
||||
# a.plan[:currenttrajectory] = newstate
|
||||
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user