This commit is contained in:
narawat lamaiin
2024-06-01 08:18:09 +07:00
parent 97c566a9d5
commit 1c863cd8ca
2 changed files with 123 additions and 119 deletions

View File

@@ -265,7 +265,8 @@ julia>
# Signature
"""
function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
function evaluator(config::T1, state::T2
)::Tuple{String, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
systemmsg =
"""
@@ -342,19 +343,19 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
"""
pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
# apply LLM specific instruct format
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "evaluator",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
@@ -377,7 +378,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
Here is an expected JSON format:
{"evaluation": "...", "score": "..."}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
evaluationDict = copy(JSON3.read(responseJsonStr))
# check if dict has all required value
@@ -671,7 +672,7 @@ julia>
# Signature
"""
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
function reflector(config::T1, state::T2)::String where {T1<:AbstractDict, T2<:AbstractDict}
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
_prompt =
@@ -727,7 +728,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
"""
# apply LLM specific instruct format
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
llminfo = externalService[:llminfo]
prompt =
if llminfo[:name] == "llama3instruct"
@@ -739,10 +740,10 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
msgMeta = GeneralUtils.generate_msgMeta(
a.config[:externalservice][:text2textinstruct][:mqtttopic],
senderName= "reflector",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
@@ -765,7 +766,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
Here is an expected JSON format:
{"reflection": "..."}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
reflectionDict = copy(JSON3.read(responseJsonStr))
# check if dict has all required value
@@ -791,45 +792,100 @@ function transition()
end
# """ Determine whether the state is a terminal state
# # Arguments
# - `state::T`
# a game state
# # Return
# - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
# # Example
# ```jldoctest
# julia>
# ```
""" Get a new state
# # TODO
# Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state
- `thoughtDict::T3`
contain Thought, Action, Observation
- `isterminal::Function`
a function to determine terminal state
# # Signature
# """
# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
# latestObservation = state[:thoughtHistory][latestObservationKey]
# Return
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
# if latestObservation !== nothing
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought_1=> "The customer wants to buy a bottle of wine.",
:action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:observation_1 => ""
)
```
# # terminal condition is when the user select wine by putting <<winename>> in latest observation
# if occursin("<<", latestObservation) && occursin(">>", latestObservation)
# isterminalstate = true
# reward = 1
# else
# isterminalstate = false
# reward = 0
# end
# else
# isterminalstate = false
# reward = 0
# end
# TODO
- [] add other actions
- [WORKING] add embedding of newstate and store in newstate[:embedding]
# Signature
"""
function transition(config::T1, state::T2, decisionMaker::Function, evaluator::Function,
reflector::Function
)::Tuple{String, Dict{Symbol, <:Any}, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
thoughtDict = decisionMaker(config, state)
actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input]
# map action and input() to llm function
response, select, reward, isterminal =
if actionname == "chatbox"
# deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
# so that other simulation start from this same node is not contaminated with actioninput
virtualWineUserChatbox(config, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
elseif actionname == "winestock"
winestock(config, actioninput)
elseif actionname == "recommendbox"
virtualWineUserRecommendbox(config, actioninput)
else
error("undefined LLM function. Requesting $actionname")
end
newNodeKey, newstate = LLMMCTS.makeNewState(state, thoughtDict, response, select, reward,
isterminal)
if actionname == "chatbox"
push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
end
stateevaluation, progressvalue = evaluator(config, newstate)
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(config, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
return (newNodeKey, newstate, progressvalue)
end
# return (isterminalstate, reward)
# end
""" Chat with llm.
@@ -960,58 +1016,6 @@ end
# function conversation(a::T, userinput::Dict) where {T<:agent}
# # get new user msg from a.receiveUserMsgChannel
# # "newtopic" command to delete chat history
# if userinput[:text] == "newtopic"
# clearhistory(a)
# return "Okay. What shall we talk about?"
# else
# # add usermsg to a.chathistory
# addNewMessage(a, "user", userinput[:text])
# currentstate =
# if isempty(a.plan[:currenttrajectory])
# # set up initial state
# Dict{Symbol, Any}(
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
# :userselect=> nothing,
# :reward=> 0,
# :isterminal=> false,
# :evaluation=> nothing,
# :lesson=> nothing,
# :thoughtDict=> nothing,
# :totalTrajectoryReward=> nothing,
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# # :recap=>,
# :question=> userinput[:text],
# )
# )
# else
# a.plan[:currenttrajectory]
# end
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
# # transition
# newstate = transition(a, bestNextState)
# a.plan[:currenttrajectory] = newstate
# end
# end

View File

@@ -3,7 +3,7 @@ module llmfunction
export virtualWineUserChatbox, jsoncorrection, winestock,
virtualWineUserRecommendbox, userChatbox, userRecommendbox
using HTTP, JSON3, URIs, Random, PrettyPrinting
using HTTP, JSON3, URIs, Random, PrettyPrinting, UUIDs
using GeneralUtils
using ..type, ..util
@@ -168,8 +168,8 @@ julia>
# Signature
"""
function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
function virtualWineUserChatbox(config::T1, input::T2, virtualCustomerChatHistory
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString}
previouswines =
"""
@@ -275,16 +275,16 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
"""
pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "virtualWineUserChatbox",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished
)
@@ -310,7 +310,7 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
"isterminal": "..."
}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
responseDict = copy(JSON3.read(responseJsonStr))
text::AbstractString = responseDict[:text]
@@ -362,8 +362,8 @@ julia> result = winestock(agent, input)
# Signature
"""
function winestock(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
function winestock(config::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString}
# SELECT *
# FROM food
@@ -469,16 +469,16 @@ function winestock(a::T1, input::T2
"""
pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "virtualWineUserChatbox",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished
)
@@ -505,7 +505,7 @@ function winestock(a::T1, input::T2
}
}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
_responseDict = copy(JSON3.read(responseJsonStr))
responseDict = _responseDict[:attributes]
@@ -559,8 +559,8 @@ function winestock(a::T1, input::T2
# return result, nothing, 0, false
end
function wineattributes_wordToNumber(a::T1, input::T2
)::Dict where {T1<:agent, T2<:AbstractString}
function wineattributes_wordToNumber(config::T1, input::T2
)::Dict where {T1<:AbstractDict, T2<:AbstractString}
systemmsg =
"""
@@ -656,16 +656,16 @@ function wineattributes_wordToNumber(a::T1, input::T2
"""
pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "wineattributes_wordToNumber",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
@@ -691,7 +691,7 @@ function wineattributes_wordToNumber(a::T1, input::T2
}
}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
_responseDict = copy(JSON3.read(responseJsonStr))
responseDict = _responseDict[:attributes]