This commit is contained in:
narawat lamaiin
2024-06-01 08:18:09 +07:00
parent 97c566a9d5
commit 1c863cd8ca
2 changed files with 123 additions and 119 deletions

View File

@@ -265,7 +265,8 @@ julia>
# Signature # Signature
""" """
function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict} function evaluator(config::T1, state::T2
)::Tuple{String, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
systemmsg = systemmsg =
""" """
@@ -342,19 +343,19 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
""" """
pprint(prompt) pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct] externalService = config[:externalservice][:text2textinstruct]
# apply LLM specific instruct format # apply LLM specific instruct format
externalService = a.config[:externalservice][:text2textinstruct] externalService = config[:externalservice][:text2textinstruct]
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic], externalService[:mqtttopic],
senderName= "evaluator", senderName= "evaluator",
senderId= a.id, senderId= string(uuid4()),
receiverName= "text2textinstruct", receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= config[:mqttServerInfo][:port],
) )
outgoingMsg = Dict( outgoingMsg = Dict(
@@ -377,7 +378,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
Here is an expected JSON format: Here is an expected JSON format:
{"evaluation": "...", "score": "..."} {"evaluation": "...", "score": "..."}
""" """
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
evaluationDict = copy(JSON3.read(responseJsonStr)) evaluationDict = copy(JSON3.read(responseJsonStr))
# check if dict has all required value # check if dict has all required value
@@ -671,7 +672,7 @@ julia>
# Signature # Signature
""" """
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} function reflector(config::T1, state::T2)::String where {T1<:AbstractDict, T2<:AbstractDict}
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py # https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
_prompt = _prompt =
@@ -727,7 +728,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
""" """
# apply LLM specific instruct format # apply LLM specific instruct format
externalService = a.config[:externalservice][:text2textinstruct] externalService = config[:externalservice][:text2textinstruct]
llminfo = externalService[:llminfo] llminfo = externalService[:llminfo]
prompt = prompt =
if llminfo[:name] == "llama3instruct" if llminfo[:name] == "llama3instruct"
@@ -739,10 +740,10 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
a.config[:externalservice][:text2textinstruct][:mqtttopic], a.config[:externalservice][:text2textinstruct][:mqtttopic],
senderName= "reflector", senderName= "reflector",
senderId= a.id, senderId= string(uuid4()),
receiverName= "text2textinstruct", receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= config[:mqttServerInfo][:port],
) )
outgoingMsg = Dict( outgoingMsg = Dict(
@@ -765,7 +766,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
Here is an expected JSON format: Here is an expected JSON format:
{"reflection": "..."} {"reflection": "..."}
""" """
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
reflectionDict = copy(JSON3.read(responseJsonStr)) reflectionDict = copy(JSON3.read(responseJsonStr))
# check if dict has all required value # check if dict has all required value
@@ -791,45 +792,100 @@ function transition()
end end
# """ Determine whether the state is a terminal state
# # Arguments
# - `state::T`
# a game state
# # Return
# - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
# # Example """ Get a new state
# ```jldoctest
# julia>
# ```
# # TODO # Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state
- `thoughtDict::T3`
contain Thought, Action, Observation
- `isterminal::Function`
a function to determine terminal state
# # Signature # Return
# """ - `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
# latestObservation = state[:thoughtHistory][latestObservationKey]
# if latestObservation !== nothing # Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought_1=> "The customer wants to buy a bottle of wine.",
:action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:observation_1 => ""
)
```
# # terminal condition is when the user select wine by putting <<winename>> in latest observation # TODO
# if occursin("<<", latestObservation) && occursin(">>", latestObservation) - [] add other actions
# isterminalstate = true - [WORKING] add embedding of newstate and store in newstate[:embedding]
# reward = 1
# else # Signature
# isterminalstate = false """
# reward = 0 function transition(config::T1, state::T2, decisionMaker::Function, evaluator::Function,
# end reflector::Function
# else )::Tuple{String, Dict{Symbol, <:Any}, Integer} where {T1<:AbstractDict, T2<:AbstractDict}
# isterminalstate = false
# reward = 0 thoughtDict = decisionMaker(config, state)
# end
actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input]
# map action and input() to llm function
response, select, reward, isterminal =
if actionname == "chatbox"
# deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
# so that other simulation start from this same node is not contaminated with actioninput
virtualWineUserChatbox(config, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
elseif actionname == "winestock"
winestock(config, actioninput)
elseif actionname == "recommendbox"
virtualWineUserRecommendbox(config, actioninput)
else
error("undefined LLM function. Requesting $actionname")
end
newNodeKey, newstate = LLMMCTS.makeNewState(state, thoughtDict, response, select, reward,
isterminal)
if actionname == "chatbox"
push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
end
stateevaluation, progressvalue = evaluator(config, newstate)
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(config, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
return (newNodeKey, newstate, progressvalue)
end
# return (isterminalstate, reward)
# end
""" Chat with llm. """ Chat with llm.
@@ -960,58 +1016,6 @@ end
# function conversation(a::T, userinput::Dict) where {T<:agent}
# # get new user msg from a.receiveUserMsgChannel
# # "newtopic" command to delete chat history
# if userinput[:text] == "newtopic"
# clearhistory(a)
# return "Okay. What shall we talk about?"
# else
# # add usermsg to a.chathistory
# addNewMessage(a, "user", userinput[:text])
# currentstate =
# if isempty(a.plan[:currenttrajectory])
# # set up initial state
# Dict{Symbol, Any}(
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
# :userselect=> nothing,
# :reward=> 0,
# :isterminal=> false,
# :evaluation=> nothing,
# :lesson=> nothing,
# :thoughtDict=> nothing,
# :totalTrajectoryReward=> nothing,
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# # :recap=>,
# :question=> userinput[:text],
# )
# )
# else
# a.plan[:currenttrajectory]
# end
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
# # transition
# newstate = transition(a, bestNextState)
# a.plan[:currenttrajectory] = newstate
# end
# end

View File

@@ -3,7 +3,7 @@ module llmfunction
export virtualWineUserChatbox, jsoncorrection, winestock, export virtualWineUserChatbox, jsoncorrection, winestock,
virtualWineUserRecommendbox, userChatbox, userRecommendbox virtualWineUserRecommendbox, userChatbox, userRecommendbox
using HTTP, JSON3, URIs, Random, PrettyPrinting using HTTP, JSON3, URIs, Random, PrettyPrinting, UUIDs
using GeneralUtils using GeneralUtils
using ..type, ..util using ..type, ..util
@@ -168,8 +168,8 @@ julia>
# Signature # Signature
""" """
function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory function virtualWineUserChatbox(config::T1, input::T2, virtualCustomerChatHistory
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString}
previouswines = previouswines =
""" """
@@ -275,16 +275,16 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
""" """
pprint(prompt) pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct] externalService = config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic], externalService[:mqtttopic],
senderName= "virtualWineUserChatbox", senderName= "virtualWineUserChatbox",
senderId= a.id, senderId= string(uuid4()),
receiverName= "text2textinstruct", receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished msgId = "dummyid" #CHANGE remove after testing finished
) )
@@ -310,7 +310,7 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
"isterminal": "..." "isterminal": "..."
} }
""" """
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
responseDict = copy(JSON3.read(responseJsonStr)) responseDict = copy(JSON3.read(responseJsonStr))
text::AbstractString = responseDict[:text] text::AbstractString = responseDict[:text]
@@ -362,8 +362,8 @@ julia> result = winestock(agent, input)
# Signature # Signature
""" """
function winestock(a::T1, input::T2 function winestock(config::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:AbstractDict, T2<:AbstractString}
# SELECT * # SELECT *
# FROM food # FROM food
@@ -469,16 +469,16 @@ function winestock(a::T1, input::T2
""" """
pprint(prompt) pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct] externalService = config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic], externalService[:mqtttopic],
senderName= "virtualWineUserChatbox", senderName= "virtualWineUserChatbox",
senderId= a.id, senderId= string(uuid4()),
receiverName= "text2textinstruct", receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished msgId = "dummyid" #CHANGE remove after testing finished
) )
@@ -505,7 +505,7 @@ function winestock(a::T1, input::T2
} }
} }
""" """
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
_responseDict = copy(JSON3.read(responseJsonStr)) _responseDict = copy(JSON3.read(responseJsonStr))
responseDict = _responseDict[:attributes] responseDict = _responseDict[:attributes]
@@ -559,8 +559,8 @@ function winestock(a::T1, input::T2
# return result, nothing, 0, false # return result, nothing, 0, false
end end
function wineattributes_wordToNumber(a::T1, input::T2 function wineattributes_wordToNumber(config::T1, input::T2
)::Dict where {T1<:agent, T2<:AbstractString} )::Dict where {T1<:AbstractDict, T2<:AbstractString}
systemmsg = systemmsg =
""" """
@@ -656,16 +656,16 @@ function wineattributes_wordToNumber(a::T1, input::T2
""" """
pprint(prompt) pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct] externalService = config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic], externalService[:mqtttopic],
senderName= "wineattributes_wordToNumber", senderName= "wineattributes_wordToNumber",
senderId= a.id, senderId= string(uuid4()),
receiverName= "text2textinstruct", receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= config[:mqttServerInfo][:port],
) )
outgoingMsg = Dict( outgoingMsg = Dict(
@@ -691,7 +691,7 @@ function wineattributes_wordToNumber(a::T1, input::T2
} }
} }
""" """
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
_responseDict = copy(JSON3.read(responseJsonStr)) _responseDict = copy(JSON3.read(responseJsonStr))
responseDict = _responseDict[:attributes] responseDict = _responseDict[:attributes]