update
This commit is contained in:
198
src/interface.jl
198
src/interface.jl
@@ -265,7 +265,8 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
function evaluator(config::T1, state::T2
|
||||||
|
)::Tuple{String, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||||
|
|
||||||
systemmsg =
|
systemmsg =
|
||||||
"""
|
"""
|
||||||
@@ -342,19 +343,19 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pprint(prompt)
|
pprint(prompt)
|
||||||
externalService = a.config[:externalservice][:text2textinstruct]
|
externalService = config[:externalservice][:text2textinstruct]
|
||||||
|
|
||||||
|
|
||||||
# apply LLM specific instruct format
|
# apply LLM specific instruct format
|
||||||
externalService = a.config[:externalservice][:text2textinstruct]
|
externalService = config[:externalservice][:text2textinstruct]
|
||||||
|
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
externalService[:mqtttopic],
|
externalService[:mqtttopic],
|
||||||
senderName= "evaluator",
|
senderName= "evaluator",
|
||||||
senderId= a.id,
|
senderId= string(uuid4()),
|
||||||
receiverName= "text2textinstruct",
|
receiverName= "text2textinstruct",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= config[:mqttServerInfo][:broker],
|
||||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||||
)
|
)
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
@@ -377,7 +378,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
|||||||
Here is an expected JSON format:
|
Here is an expected JSON format:
|
||||||
{"evaluation": "...", "score": "..."}
|
{"evaluation": "...", "score": "..."}
|
||||||
"""
|
"""
|
||||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||||
evaluationDict = copy(JSON3.read(responseJsonStr))
|
evaluationDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
# check if dict has all required value
|
# check if dict has all required value
|
||||||
@@ -671,7 +672,7 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
function reflector(config::T1, state::T2)::String where {T1<:AbstractDict, T2<:AbstractDict}
|
||||||
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
|
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
|
||||||
|
|
||||||
_prompt =
|
_prompt =
|
||||||
@@ -727,7 +728,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# apply LLM specific instruct format
|
# apply LLM specific instruct format
|
||||||
externalService = a.config[:externalservice][:text2textinstruct]
|
externalService = config[:externalservice][:text2textinstruct]
|
||||||
llminfo = externalService[:llminfo]
|
llminfo = externalService[:llminfo]
|
||||||
prompt =
|
prompt =
|
||||||
if llminfo[:name] == "llama3instruct"
|
if llminfo[:name] == "llama3instruct"
|
||||||
@@ -739,10 +740,10 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
|||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
||||||
senderName= "reflector",
|
senderName= "reflector",
|
||||||
senderId= a.id,
|
senderId= string(uuid4()),
|
||||||
receiverName= "text2textinstruct",
|
receiverName= "text2textinstruct",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= config[:mqttServerInfo][:broker],
|
||||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||||
)
|
)
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
@@ -765,7 +766,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
|||||||
Here is an expected JSON format:
|
Here is an expected JSON format:
|
||||||
{"reflection": "..."}
|
{"reflection": "..."}
|
||||||
"""
|
"""
|
||||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||||
reflectionDict = copy(JSON3.read(responseJsonStr))
|
reflectionDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
# check if dict has all required value
|
# check if dict has all required value
|
||||||
@@ -791,45 +792,100 @@ function transition()
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
# """ Determine whether the state is a terminal state
|
|
||||||
|
|
||||||
# # Arguments
|
|
||||||
# - `state::T`
|
|
||||||
# a game state
|
|
||||||
|
|
||||||
# # Return
|
|
||||||
# - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
|
|
||||||
|
|
||||||
# # Example
|
""" Get a new state
|
||||||
# ```jldoctest
|
|
||||||
# julia>
|
|
||||||
# ```
|
|
||||||
|
|
||||||
# # TODO
|
# Arguments
|
||||||
|
- `a::T1`
|
||||||
|
one of YiemAgent's agent
|
||||||
|
- `state::T2`
|
||||||
|
current game state
|
||||||
|
- `thoughtDict::T3`
|
||||||
|
contain Thought, Action, Observation
|
||||||
|
- `isterminal::Function`
|
||||||
|
a function to determine terminal state
|
||||||
|
|
||||||
# # Signature
|
# Return
|
||||||
# """
|
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
|
||||||
# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
|
|
||||||
# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
|
|
||||||
# latestObservation = state[:thoughtHistory][latestObservationKey]
|
|
||||||
|
|
||||||
# if latestObservation !== nothing
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||||
|
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
|
||||||
|
:storeinfo => Dict(),
|
||||||
|
:customerinfo => Dict()
|
||||||
|
)
|
||||||
|
julia> thoughtDict = Dict(
|
||||||
|
:question=> "I want to buy a bottle of wine.",
|
||||||
|
:thought_1=> "The customer wants to buy a bottle of wine.",
|
||||||
|
:action_1=> Dict{Symbol, Any}(
|
||||||
|
:name=>"Chatbox",
|
||||||
|
:input=>"What occasion are you buying the wine for?",
|
||||||
|
),
|
||||||
|
:observation_1 => ""
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
# # terminal condition is when the user select wine by putting <<winename>> in latest observation
|
# TODO
|
||||||
# if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
- [] add other actions
|
||||||
# isterminalstate = true
|
- [WORKING] add embedding of newstate and store in newstate[:embedding]
|
||||||
# reward = 1
|
|
||||||
# else
|
# Signature
|
||||||
# isterminalstate = false
|
"""
|
||||||
# reward = 0
|
function transition(config::T1, state::T2, decisionMaker::Function, evaluator::Function,
|
||||||
# end
|
reflector::Function
|
||||||
# else
|
)::Tuple{String, Dict{Symbol, <:Any}, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||||
# isterminalstate = false
|
|
||||||
# reward = 0
|
thoughtDict = decisionMaker(config, state)
|
||||||
# end
|
|
||||||
|
actionname = thoughtDict[:action][:name]
|
||||||
|
actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
|
# map action and input() to llm function
|
||||||
|
response, select, reward, isterminal =
|
||||||
|
if actionname == "chatbox"
|
||||||
|
# deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
|
||||||
|
# so that other simulation start from this same node is not contaminated with actioninput
|
||||||
|
virtualWineUserChatbox(config, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
|
||||||
|
elseif actionname == "winestock"
|
||||||
|
winestock(config, actioninput)
|
||||||
|
elseif actionname == "recommendbox"
|
||||||
|
virtualWineUserRecommendbox(config, actioninput)
|
||||||
|
else
|
||||||
|
error("undefined LLM function. Requesting $actionname")
|
||||||
|
end
|
||||||
|
|
||||||
|
newNodeKey, newstate = LLMMCTS.makeNewState(state, thoughtDict, response, select, reward,
|
||||||
|
isterminal)
|
||||||
|
if actionname == "chatbox"
|
||||||
|
push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
|
||||||
|
push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
|
||||||
|
end
|
||||||
|
|
||||||
|
stateevaluation, progressvalue = evaluator(config, newstate)
|
||||||
|
|
||||||
|
if newstate[:reward] < 0
|
||||||
|
pprint(newstate[:thoughtHistory])
|
||||||
|
newstate[:evaluation] = stateevaluation
|
||||||
|
newstate[:lesson] = reflector(config, newstate)
|
||||||
|
|
||||||
|
# store new lesson for later use
|
||||||
|
lessonDict = copy(JSON3.read("lesson.json"))
|
||||||
|
latestLessonKey, latestLessonIndice =
|
||||||
|
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
|
||||||
|
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||||
|
newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||||
|
lessonDict[newLessonKey] = newstate
|
||||||
|
open("lesson.json", "w") do io
|
||||||
|
JSON3.pretty(io, lessonDict)
|
||||||
|
end
|
||||||
|
print("---> reflector()")
|
||||||
|
end
|
||||||
|
|
||||||
|
return (newNodeKey, newstate, progressvalue)
|
||||||
|
end
|
||||||
|
|
||||||
# return (isterminalstate, reward)
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
|
||||||
""" Chat with llm.
|
""" Chat with llm.
|
||||||
@@ -960,58 +1016,6 @@ end
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
# function conversation(a::T, userinput::Dict) where {T<:agent}
|
|
||||||
|
|
||||||
# # get new user msg from a.receiveUserMsgChannel
|
|
||||||
|
|
||||||
|
|
||||||
# # "newtopic" command to delete chat history
|
|
||||||
# if userinput[:text] == "newtopic"
|
|
||||||
# clearhistory(a)
|
|
||||||
|
|
||||||
# return "Okay. What shall we talk about?"
|
|
||||||
|
|
||||||
# else
|
|
||||||
# # add usermsg to a.chathistory
|
|
||||||
# addNewMessage(a, "user", userinput[:text])
|
|
||||||
|
|
||||||
# currentstate =
|
|
||||||
# if isempty(a.plan[:currenttrajectory])
|
|
||||||
# # set up initial state
|
|
||||||
# Dict{Symbol, Any}(
|
|
||||||
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
|
||||||
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
|
||||||
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
|
||||||
# :userselect=> nothing,
|
|
||||||
# :reward=> 0,
|
|
||||||
# :isterminal=> false,
|
|
||||||
# :evaluation=> nothing,
|
|
||||||
# :lesson=> nothing,
|
|
||||||
# :thoughtDict=> nothing,
|
|
||||||
# :totalTrajectoryReward=> nothing,
|
|
||||||
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
|
||||||
# # :recap=>,
|
|
||||||
# :question=> userinput[:text],
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
# else
|
|
||||||
# a.plan[:currenttrajectory]
|
|
||||||
# end
|
|
||||||
|
|
||||||
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
|
||||||
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
|
||||||
|
|
||||||
# # transition
|
|
||||||
# newstate = transition(a, bestNextState)
|
|
||||||
# a.plan[:currenttrajectory] = newstate
|
|
||||||
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ module llmfunction
|
|||||||
export virtualWineUserChatbox, jsoncorrection, winestock,
|
export virtualWineUserChatbox, jsoncorrection, winestock,
|
||||||
virtualWineUserRecommendbox, userChatbox, userRecommendbox
|
virtualWineUserRecommendbox, userChatbox, userRecommendbox
|
||||||
|
|
||||||
using HTTP, JSON3, URIs, Random, PrettyPrinting
|
using HTTP, JSON3, URIs, Random, PrettyPrinting, UUIDs
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..type, ..util
|
using ..type, ..util
|
||||||
|
|
||||||
@@ -168,8 +168,8 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
|
function virtualWineUserChatbox(config::T1, input::T2, virtualCustomerChatHistory
|
||||||
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString}
|
||||||
|
|
||||||
previouswines =
|
previouswines =
|
||||||
"""
|
"""
|
||||||
@@ -275,16 +275,16 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pprint(prompt)
|
pprint(prompt)
|
||||||
externalService = a.config[:externalservice][:text2textinstruct]
|
externalService = config[:externalservice][:text2textinstruct]
|
||||||
|
|
||||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
externalService[:mqtttopic],
|
externalService[:mqtttopic],
|
||||||
senderName= "virtualWineUserChatbox",
|
senderName= "virtualWineUserChatbox",
|
||||||
senderId= a.id,
|
senderId= string(uuid4()),
|
||||||
receiverName= "text2textinstruct",
|
receiverName= "text2textinstruct",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= config[:mqttServerInfo][:broker],
|
||||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||||
msgId = "dummyid" #CHANGE remove after testing finished
|
msgId = "dummyid" #CHANGE remove after testing finished
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -310,7 +310,7 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
|
|||||||
"isterminal": "..."
|
"isterminal": "..."
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||||
responseDict = copy(JSON3.read(responseJsonStr))
|
responseDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
text::AbstractString = responseDict[:text]
|
text::AbstractString = responseDict[:text]
|
||||||
@@ -362,8 +362,8 @@ julia> result = winestock(agent, input)
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function winestock(a::T1, input::T2
|
function winestock(config::T1, input::T2
|
||||||
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString}
|
||||||
|
|
||||||
# SELECT *
|
# SELECT *
|
||||||
# FROM food
|
# FROM food
|
||||||
@@ -469,16 +469,16 @@ function winestock(a::T1, input::T2
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pprint(prompt)
|
pprint(prompt)
|
||||||
externalService = a.config[:externalservice][:text2textinstruct]
|
externalService = config[:externalservice][:text2textinstruct]
|
||||||
|
|
||||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
externalService[:mqtttopic],
|
externalService[:mqtttopic],
|
||||||
senderName= "virtualWineUserChatbox",
|
senderName= "virtualWineUserChatbox",
|
||||||
senderId= a.id,
|
senderId= string(uuid4()),
|
||||||
receiverName= "text2textinstruct",
|
receiverName= "text2textinstruct",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= config[:mqttServerInfo][:broker],
|
||||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||||
msgId = "dummyid" #CHANGE remove after testing finished
|
msgId = "dummyid" #CHANGE remove after testing finished
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -505,7 +505,7 @@ function winestock(a::T1, input::T2
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||||
_responseDict = copy(JSON3.read(responseJsonStr))
|
_responseDict = copy(JSON3.read(responseJsonStr))
|
||||||
responseDict = _responseDict[:attributes]
|
responseDict = _responseDict[:attributes]
|
||||||
|
|
||||||
@@ -559,8 +559,8 @@ function winestock(a::T1, input::T2
|
|||||||
# return result, nothing, 0, false
|
# return result, nothing, 0, false
|
||||||
end
|
end
|
||||||
|
|
||||||
function wineattributes_wordToNumber(a::T1, input::T2
|
function wineattributes_wordToNumber(config::T1, input::T2
|
||||||
)::Dict where {T1<:agent, T2<:AbstractString}
|
)::Dict where {T1<:AbstractDict, T2<:AbstractString}
|
||||||
|
|
||||||
systemmsg =
|
systemmsg =
|
||||||
"""
|
"""
|
||||||
@@ -656,16 +656,16 @@ function wineattributes_wordToNumber(a::T1, input::T2
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pprint(prompt)
|
pprint(prompt)
|
||||||
externalService = a.config[:externalservice][:text2textinstruct]
|
externalService = config[:externalservice][:text2textinstruct]
|
||||||
|
|
||||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
externalService[:mqtttopic],
|
externalService[:mqtttopic],
|
||||||
senderName= "wineattributes_wordToNumber",
|
senderName= "wineattributes_wordToNumber",
|
||||||
senderId= a.id,
|
senderId= string(uuid4()),
|
||||||
receiverName= "text2textinstruct",
|
receiverName= "text2textinstruct",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= config[:mqttServerInfo][:broker],
|
||||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
mqttBrokerPort= config[:mqttServerInfo][:port],
|
||||||
)
|
)
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
@@ -691,7 +691,7 @@ function wineattributes_wordToNumber(a::T1, input::T2
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
|
||||||
_responseDict = copy(JSON3.read(responseJsonStr))
|
_responseDict = copy(JSON3.read(responseJsonStr))
|
||||||
responseDict = _responseDict[:attributes]
|
responseDict = _responseDict[:attributes]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user