This commit is contained in:
narawat lamaiin
2024-05-13 17:37:44 +07:00
parent 8431258f1c
commit 62c6ce90ed
4 changed files with 432 additions and 86 deletions

View File

@@ -295,7 +295,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
}
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
It is also better to have simple searches corresponding to a single entity, making this the best action.",
"score": 7
"score": 10
}
{
@@ -309,7 +309,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
}
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
"score": 3
"score": 0
}
Let's begin!:
@@ -378,6 +378,143 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
end
# """
# # Arguments
# - `a::T1`
# one of Yiem's agent
# - `state::T2`
# a game state
# # Return
# - `evaluation::Tuple{String, Integer}`
# evaluation and score
# # Example
# ```jldoctest
# julia>
# ```
# # TODO
# - [] update docs
# - [] implement the function
# # Signature
# """
# function comparer(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
# _prompt =
# """
# Analyze the trajectories of a solution to a question answering task. The trajectories are
# labeled by environmental observations about the situation, thoughts that can reason about
# the current situation and actions that can be three types:
# 1) winestock[query], which you can use to find wine in your inventory.
# 2) chatbox[text], which you can use to interact with the user.
# 3) recommendbox[answer], which returns your wine recommendation to the user.
# Given a question and a trajectory, evaluate its correctness and provide your reasoning and
# analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
# can be correct if the thoughts and actions so far are correct, even if the answer is not found
# yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
# where s is an integer from 0 to 10.
# You should only respond in JSON format as describe below:
# {"evaluation": "your evaluation", "score": "your evaluation score"}
# Here are some examples:
# {
# "question": "I'm looking for a sedan with an automatic driving feature.",
# "thought_1": "I have many types of sedans in my inventory, each with diverse features.",
# "thought_2": "But there is only 1 model that has the feature customer wanted.",
# "thought_3": "I should check our inventory first to see if we have it.",
# "action_1": {"name": "inventory", "input": "Yiem model A"},
# "observation_1": "Yiem model A is in stock."
# }
# {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
# It is also better to have simple searches corresponding to a single entity, making this the best action.",
# "score": 10
# }
# {
# "question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?",
# "thought_1": "Let me check our inventory first to see if I have it.",
# "action_1": {"name": "inventory", "input": "pen with 4 color and a pencil."},
# "observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}",
# "thought_2": "Ok, I have what the user is asking. Let's tell the user.",
# "action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"},
# "observation_1": "This is not what I wanted."
# }
# {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
# not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
# "score": 0
# }
# Let's begin!:
# $(JSON3.write(state[:thoughtHistory]))
# {"evaluation"
# """
# # apply LLM specific instruct format
# externalService = a.config[:externalservice][:text2textinstruct]
# llminfo = externalService[:llminfo]
# prompt =
# if llminfo[:name] == "llama3instruct"
# formatLLMtext_llama3instruct("system", _prompt)
# else
# error("llm model name is not defied yet $(@__LINE__)")
# end
# msgMeta = GeneralUtils.generate_msgMeta(
# a.config[:externalservice][:text2textinstruct][:mqtttopic],
# senderName= "evaluator",
# senderId= a.id,
# receiverName= "text2textinstruct",
# mqttBroker= a.config[:mqttServerInfo][:broker],
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
# )
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :text=> prompt,
# :kwargs=> Dict(
# :max_tokens=> 512,
# :stop=> ["<|eot_id|>"],
# )
# )
# )
# for attempt in 1:5
# try
# response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
# _responseJsonStr = response[:response][:text]
# expectedJsonExample =
# """
# Here is an expected JSON format:
# {"evaluation": "...", "score": "..."}
# """
# responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
# evaluationDict = copy(JSON3.read(responseJsonStr))
# # check if dict has all required value
# dummya::AbstractString = evaluationDict[:evaluation]
# dummyb::Integer = evaluationDict[:score]
# return (evaluationDict[:evaluation], evaluationDict[:score])
# catch e
# io = IOBuffer()
# showerror(io, e)
# errorMsg = String(take!(io))
# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
# println("")
# @warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
# println("")
# end
# end
# error("evaluator failed to generate an evaluation")
# end
"""
# Arguments
@@ -600,8 +737,9 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
# TODO
- [] update docstring
- [WORKING] MCTS() for planning
- [x] MCTS() for planning
- [] add recap to initialState for earlier completed question
- [WORKING] conversation loop
# Signature
"""
@@ -617,36 +755,36 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
# add usermsg to a.chathistory
addNewMessage(a, "user", userinput[:text])
#[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned
if !isempty(a.plan[:currenttrajectory]) &&
a.plan[:currenttrajectory][end][:action] == "chatbox"
else
initialState = Dict{Symbol, Any}(
currentstate =
if isempty(a.plan[:currenttrajectory])
# set up initial state
Dict{Symbol, Any}(
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:select=> nothing,
:userselect=> nothing,
:reward=> 0,
:isterminal=> false,
:evaluation=> nothing,
:lesson=> nothing,
:thoughtDict=> nothing,
:totalTrajectoryReward=> nothing,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# :recap=>,
:question=> userinput[:text],
)
)
bestplan = runMCTS(a, initialState, decisionMaker, evaluator, reflector,
2, 3, 4, 1.0)
error("---> bestplan")
# actor loop(bestplan)
else
a.plan[:currenttrajectory]
end
bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
# transition
newstate = transition(a, bestNextState)
a.plan[:currenttrajectory] = newstate
end
end

View File

@@ -1,7 +1,7 @@
module llmfunction
export virtualWineCustomerChatbox, jsoncorrection, winestock,
virtualWineCustomerReccommendbox
export virtualWineUserChatbox, jsoncorrection, winestock,
virtualWineUserRecommendbox, userChatbox, userRecommendbox
using HTTP, JSON3, URIs, Random
using GeneralUtils
@@ -26,8 +26,46 @@ julia>
# Signature
"""
function chatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
error("--> chatbox")
function userChatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
error("--> userChatbox")
# put in model format
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
llminfo = virtualWineCustomer[:llminfo]
formattedinput =
if llminfo[:name] == "llama3instruct"
formatLLMtext_llama3instruct("assistant", input)
else
error("llm model name is not defied yet $(@__LINE__)")
end
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
# return response
end
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function
# Signature
"""
function userRecommendbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
error("--> userRecommendbox")
# put in model format
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
@@ -69,7 +107,7 @@ julia>
# Signature
"""
function virtualWineCustomerReccommendbox(a::T1, input
function virtualWineUserRecommendbox(a::T1, input
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent}
# put in model format
@@ -85,7 +123,7 @@ function virtualWineCustomerReccommendbox(a::T1, input
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
virtualWineCustomer[:mqtttopic],
senderName= "virtualWineCustomerReccommendbox",
senderName= "virtualWineUserRecommendbox",
senderId= a.id,
receiverName= "virtualWineCustomer",
mqttBroker= a.config[:mqttServerInfo][:broker],
@@ -126,11 +164,10 @@ julia>
# TODO
- [] update docs
- [] add to remove <<< user option select >>> and <<| reward |>>
# Signature
"""
function virtualWineCustomerChatbox(a::T1, input::T2
function virtualWineUserChatbox(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# put in model format
@@ -146,7 +183,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
virtualWineCustomer[:mqtttopic],
senderName= "virtualWineCustomerChatbox",
senderName= "virtualWineUserChatbox",
senderId= a.id,
receiverName= "virtualWineCustomer",
mqttBroker= a.config[:mqttServerInfo][:broker],
@@ -178,7 +215,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
println("")
end
end
error("virtualWineCustomerChatbox failed to get a response")
error("virtualWineUserChatbox failed to get a response")
end

View File

@@ -5,7 +5,8 @@
module mcts
export MCTSNode, runMCTS, isleaf
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
userChatbox
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils
@@ -51,9 +52,9 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2
state::T1
visits::Integer
progressvalue::Number
statevalue::Number
reward::Number
progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # store discounted cumulative reward (gathered from its child nodes)
reward::Number # this node's own reward
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
@@ -132,23 +133,24 @@ julia>
# Signature
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
evaluator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
evaluator::Function, reflector::Function; totalsample::Integer=3
) where {T1<:agent}
nthSample = 0
while true
nthSample += 1
if nthSample <= n
if nthSample <= totalsample
thoughtDict = decisionMaker(a, node.state)
println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
node.state[:thoughtDict] = thoughtDict
newNodeKey, newstate = MCTStransition(a, node.state)
# add evaluator
stateevaluation, progressvalue = evaluator(a, newstate)
if reward < 0
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
@@ -167,8 +169,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
end
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0,
reward, isterminalstate, node, Dict{String, MCTSNode}())
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}())
end
else
break
@@ -196,24 +199,30 @@ end
julia>
```
# TODO
- [] update docs
# Signature
"""
function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent}
reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:agent}
simTrajectoryReward = 0.0
terminalstate = nothing
for depth in 1:maxDepth
simTrajectoryReward += node.reward
if node.isterminal
terminalstate = node.state
break
else
expand(a, node, decisionMaker, evaluator, reflector; n=n)
expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
node = selectChildNode(node)
end
end
return simTrajectoryReward
return (simTrajectoryReward, terminalstate)
end
""" Backpropagate reward along the simulation chain
@@ -285,20 +294,21 @@ julia> thoughtDict = Dict(
# Signature
"""
function MCTStransition(a::T1, state::T2, thoughtDict::T3
)::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
function MCTStransition(a::T1, state::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
thoughtDict = state[:thoughtDict]
actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input]
# map action and input() to llm function
response, select, reward, isterminal =
if actionname == "chatbox"
virtualWineCustomerChatbox(a, actioninput) # virtual customer
virtualWineUserChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock"
winestock(a, actioninput)
elseif actionname == "recommendbox"
virtualWineCustomerReccommendbox(a, actioninput)
virtualWineUserRecommendbox(a, actioninput)
else
error("undefined LLM function. Requesting $actionname")
end
@@ -321,7 +331,85 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate, reward, isterminal)
return (newNodeKey, newstate)
end
""" Get a new state
# Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state; must contain `:thoughtDict` (the latest thought and its action) and `:thoughtHistory`
# Return
- `newstate::Dict{Symbol, <:Any}`
a deep copy of `state` with the new thought, action and observation appended to `:thoughtHistory` and with `:reward`, `:select` and `:isterminal` updated
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought_1=> "The customer wants to buy a bottle of wine.",
:action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:observation_1 => ""
)
```
# TODO
- [x] add other actions
- [] add embedding of newstate and store in newstate[:embedding]
# Signature
"""
function transition(a::T1, state::T2
    )::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
    # Apply the action stored in `state[:thoughtDict]` and return a new state.
    # The input `state` is not modified: the result is a deep copy whose
    # `:thoughtHistory` gains the next thought_N / action_N / observation_N entries
    # and whose `:reward`, `:select` and `:isterminal` are overwritten.

    thoughtDict = state[:thoughtDict]
    action = thoughtDict[:action]
    actionname = action[:name]
    actioninput = action[:input]

    # Dispatch the action name to its handler. Each handler's result is destructured
    # below as (response, select, reward, isterminal).
    outcome =
        if actionname == "chatbox"
            userChatbox(a, actioninput)
        elseif actionname == "winestock"
            winestock(a, actioninput)
        elseif actionname == "recommendbox"
            userRecommendbox(a, actioninput)
        else
            error("undefined LLM function. Requesting $actionname")
        end
    response, select, reward, isterminal = outcome

    # Work out the index of the next thought/action/observation triple.
    # `:NA` is treated as "no thought recorded yet", so numbering starts at 1.
    latestThoughtKey, latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought")
    nextIndice =
        if latestThoughtKey == :NA
            1
        else
            latestThoughtIndice + 1
        end
    thoughtKey = Symbol("thought_$nextIndice")
    actionKey = Symbol("action_$nextIndice")
    observationKey = Symbol("observation_$(nextIndice)")

    # Record everything on a deep copy so the caller's state stays untouched.
    newstate = deepcopy(state)
    history = newstate[:thoughtHistory]
    history[thoughtKey] = thoughtDict[:thought]
    history[actionKey] = action
    history[observationKey] = response

    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    return newstate
end
@@ -396,6 +484,90 @@ function selectChildNode(node::MCTSNode)::MCTSNode
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [TESTING] implement the function
# Signature
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
    # Return the most promising direct child of `node`.
    #
    # Ranking rule:
    #   - if at least one child has a non-zero `statevalue`, rank children by their
    #     average state value (`statevalue / visits`);
    #   - otherwise (every child still has `statevalue == 0`), rank children by
    #     `progressvalue + reward`.
    # Ties keep the first child encountered while iterating `node.children`.
    #
    # Throws `ArgumentError` when `node` has no children, since there is nothing to select.
    isempty(node.children) &&
        throw(ArgumentError("selectBestNextState: node \"$(node.nodekey)\" has no children"))

    # Test each child individually instead of summing: positive and negative state
    # values can cancel to zero and would wrongly trigger the fallback ranking.
    usestatevalue = any(child -> !iszero(child.statevalue), values(node.children))

    bestnode = nothing
    bestpotential = -Inf # start below any real score so non-positive scores can still win
    for childnode in values(node.children)
        potential =
            if usestatevalue
                # an unvisited child would yield a division by zero (NaN/Inf);
                # score it as neutral instead
                childnode.visits == 0 ? 0.0 : childnode.statevalue / childnode.visits
            else
                childnode.progressvalue + childnode.reward
            end

        if bestnode === nothing || potential > bestpotential
            bestpotential = potential
            bestnode = childnode
        end
    end

    return bestnode
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `leafNode::MCTSNode`
the leaf node reached by repeatedly following the best next state starting from `node`
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [TESTING] implement the function
# Signature
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
    # Walk down the tree from `node`, taking the child chosen by `selectBestNextState`
    # at every level, and return the leaf where the walk stops. If `node` is already
    # a leaf it is returned unchanged.
    current = node
    while true
        isleaf(current) && break
        current = selectBestNextState(current)
    end
    return current
end
""" Determine wheter a given node is a root node
# Arguments
@@ -451,7 +623,7 @@ julia>
# TODO
[] update docstring
[PENDING] return best plan
[x] return best action
# Signature
"""
@@ -460,46 +632,48 @@ function runMCTS(
initialState,
decisionMaker::Function,
evaluator::Function,
reflector::Function,
n::Integer,
maxDepth::Integer,
maxIterations::Integer,
w::Float64
reflector::Function;
totalsample::Integer=3,
maxDepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxIterations
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, w)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward)
else
expand(a, node, decisionMaker, evaluator, reflector; n=n)
expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward = simulate(a, leafNode, decisionMaker, evaluator,
reflector; maxDepth=maxDepth, n=n)
simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker, evaluator,
reflector; maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
end
avgStateValue = 0
selectedChildKey = nothing
for (k, v) in root.children
k_avgStateValue = v.statevalue / v.visits
if k_avgStateValue > avgStateValue
avgStateValue = k_avgStateValue
selectedChildKey = k
end
end
bestNextState = selectBestNextState(root)
besttrajectory = selectBestTrajectory(root)
return root.children[selectedChildKey]
return (bestNextState.state, besttrajectory.state)
end
@@ -525,9 +699,6 @@ end

View File

@@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I like it dry.",
:text=> "I like dry wine with fruity flavors.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,