This commit is contained in:
narawat lamaiin
2024-05-06 13:50:08 +07:00
parent a7fdbcede9
commit cfca2b1839
5 changed files with 188 additions and 113 deletions

View File

@@ -58,12 +58,12 @@ end
# Example # Example
```jldoctest ```jldoctest
julia> output_thoughtDict = Dict( julia> output_thoughtDict = Dict(
:Thought_1 => "The customer wants to buy a bottle of wine. This is a good start!", :thought_1 => "The customer wants to buy a bottle of wine. This is a good start!",
:Action_1 => Dict{Symbol, Any}( :action_1 => Dict{Symbol, Any}(
:action=>"Chatbox", :action=>"Chatbox",
:input=>"What occasion are you buying the wine for?" :input=>"What occasion are you buying the wine for?"
), ),
:Observation_1 => "" :observation_1 => ""
) )
``` ```
@@ -98,16 +98,6 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
# (trajectories) # (trajectories)
# """ # """
responseformat =
"""
You should only respond in JSON format as describe below:
{
"Thought": "your reasoning",
"Action": {"name": "action to take", "input": "Action input"},
"Observation": "result of the action"
}
"""
_prompt = _prompt =
""" """
You are a helpful sommelier working for a wine store. You are a helpful sommelier working for a wine store.
@@ -127,31 +117,32 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
Thought can reason about the current situation, and Action can be three types: Thought can reason about the current situation, and Action can be three types:
1) winestock[query], which you can use to find wine in your inventory. The more input data the better. 1) winestock[query], which you can use to find wine in your inventory. The more input data the better.
2) chatbox[text], which you can use to interact with the user. 2) chatbox[text], which you can use to interact with the user.
3) reccommendbox[answer], which returns your wine reccommendation to the user. 3) recommendbox[answer], which returns your wine reccommendation to the user.
$responseformat You should only respond in JSON format as describe below:
{
"thought": "your reasoning",
"action": {"name": "action to take", "input": "Action input"},
"observation": "result of the action"
}
Here are some examples: Here are some examples:
{ {
"Question": "I would like to buy a sedan with 8 seats.", "question": "I would like to buy a sedan with 8 seats.",
"Thought_1": "Our showroom carries various vehicle model. But I'm not sure whether we have a models that fits the user demand, I need to check our inventory.", "thought_1": "Our showroom carries various vehicle model. But I'm not sure whether we have a models that fits the user demand, I need to check our inventory.",
"Action_1": {"name": "inventory", "input": "sedan with 8 seats."}, "action_1": {"name": "inventory", "input": "sedan with 8 seats."},
"Observation_1": "Several model has 8 seats. Available color are black, red green" "observation_1": "Several model has 8 seats. Available color are black, red green"
} }
{ {
"Thought_2": "I have to ask the user what color he likes.", "thought": "I have a few color for the user to choose from. I will ask him what color he likes.",
"Action_2": {"name": "chatbox", "input": "Which color do you like?"} "action": {"name": "chatbox", "input": "Which color do you like?"}
"Observation_2": "I'll take black." "observation": "I'll take black."
}
{
"Thought_3": "There is only one model that fits the user preference. It's Yiem model A",
"Action_3": {"name": "recommendation", "input": "I recommend a Yiem model A"}
} }
Let's begin! Let's begin!
$(JSON3.write(state[:thoughtHistory])) $(JSON3.write(state[:thoughtHistory]))
{Thought {thought
""" """
# apply LLM specific instruct format # apply LLM specific instruct format
@@ -190,9 +181,9 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
""" """
Here is an expected JSON format: Here is an expected JSON format:
{ {
"Thought": "...", "thought": "...",
"Action": {"name": "...", "input": "..."}, "action": {"name": "...", "input": "..."},
"Observation": "..." "observation": "..."
} }
""" """
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, expectedJsonExample) thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, expectedJsonExample)
@@ -224,13 +215,6 @@ julia>
# Signature # Signature
""" """
function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict} function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
responseformat =
"""
You should only respond in JSON format as describe below:
{
"Evaluation": {"evaluation": "your evaluation", "score": "your evaluation score"}
}
"""
_prompt = _prompt =
""" """
@@ -239,7 +223,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
the current situation and actions that can be three types: the current situation and actions that can be three types:
1) winestock[query], which you can use to find wine in your inventory. 1) winestock[query], which you can use to find wine in your inventory.
2) chatbox[text], which you can use to interact with the user. 2) chatbox[text], which you can use to interact with the user.
3) reccommendbox[answer], which returns your wine reccommendation to the user. 3) recommendbox[answer], which returns your wine reccommendation to the user.
Given a question and a trajectory, evaluate its correctness and provide your reasoning and Given a question and a trajectory, evaluate its correctness and provide your reasoning and
analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
@@ -247,26 +231,25 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
yet. Do not generate additional thoughts or actions. Then ending with the correctness score s yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
where s is an integer from 1 to 10. where s is an integer from 1 to 10.
$responseformat You should only respond in JSON format as describe below:
{"evaluation": "your evaluation", "score": "your evaluation score"}
Here are some examples: Here are some examples:
{ {
"Question": "I'm looking for a sedan with an automatic driving feature.", "question": "I'm looking for a sedan with an automatic driving feature.",
"Thought_1": "I have many types of sedans in my inventory, each with diverse features.", "thought_1": "I have many types of sedans in my inventory, each with diverse features.",
"Thought_2": "But there is only 1 model that has the feature customer wanted.", "thought_2": "But there is only 1 model that has the feature customer wanted.",
"Thought_3": "I should check our inventory first to see if we have it.", "thought_3": "I should check our inventory first to see if we have it.",
"Action_1": {"name": "inventory", "input": "Yiem model A"}, "action_1": {"name": "inventory", "input": "Yiem model A"},
"Observation_1": "Yiem model A is in stock." "observation_1": "Yiem model A is in stock."
} }
{ {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
"Evaluation": {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
It is also better to have simple searches corresponding to a single entity, making this the best action.", It is also better to have simple searches corresponding to a single entity, making this the best action.",
"score": 10} "score": 10}
}
Let's begin!: Let's begin!:
$(JSON3.write(state[:thoughtHistory])) $(JSON3.write(state[:thoughtHistory]))
{Evaluation {evaluation
""" """
# apply LLM specific instruct format # apply LLM specific instruct format
@@ -304,15 +287,12 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
expectedJsonExample = expectedJsonExample =
""" """
Here is an expected JSON format: Here is an expected JSON format:
{ {"evaluation": "...", "score": "..."}
"Evaluation": {"evaluation": "...", "score": "..."}
}
""" """
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, expectedJsonExample) resultJsonStr = jsoncorrection(a, _thoughtJsonStr, expectedJsonExample)
thoughtDict = copy(JSON3.read(thoughtJsonStr)) resultDict = copy(JSON3.read(resultJsonStr))
evaluation = thoughtDict[:Evaluation]
return evaluation[:evaluation], evaluation[:score] return resultDict[:evaluation], resultDict[:score]
end end
@@ -355,7 +335,7 @@ julia>
# Signature # Signature
""" """
function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict} function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation") latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
latestObservation = state[:thoughtHistory][latestObservationKey] latestObservation = state[:thoughtHistory][latestObservationKey]
if latestObservation !== nothing if latestObservation !== nothing
@@ -455,7 +435,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
:Question=> userinput[:text], :question=> userinput[:text],
) )
) )
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,

View File

@@ -184,7 +184,7 @@ julia> result = winestock(agent, input)
# TODO # TODO
[] update docs [] update docs
[] implement the function [PENDING] implement the function
# Signature # Signature
""" """

View File

@@ -47,10 +47,11 @@ julia> state = Dict(
# Signature # Signature
""" """
mutable struct MCTSNode{T<:AbstractDict} mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::String nodekey::T2
state::T state::T1
visits::Integer visits::Integer
stateevaluation::T2
statevalue::Number statevalue::Number
reward::Number reward::Number
isterminal::Bool isterminal::Bool
@@ -74,7 +75,7 @@ julia>
# TODO # TODO
[] update docstring [] update docstring
[TESTING] check childNode.total_reward w/ LATS paper. Which value total_reward representing [x] check childNode.total_reward w/ LATS paper. Which value total_reward representing
# Signature # Signature
""" """
@@ -83,8 +84,14 @@ function UCTselect(node::MCTSNode, w::Float64)
selectedNode = nothing selectedNode = nothing
for (childState, childNode) in node.children for (childState, childNode) in node.children
uctValue = childNode.statevalue + weightedterm =
if node.visits == 0 || childNode.visits == 0
0
else
w * sqrt(log(node.visits) / childNode.visits) w * sqrt(log(node.visits) / childNode.visits)
end
uctValue = childNode.statevalue + weightedterm
if uctValue > max_uct if uctValue > max_uct
max_uct = uctValue max_uct = uctValue
selectedNode = childNode selectedNode = childNode
@@ -132,11 +139,10 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
isterminal) isterminal)
# add progressValueEstimator # add progressValueEstimator
progressRationale, statevalue = progressValueEstimator(a, newstate) stateevaluation, statevalue = progressValueEstimator(a, newstate)
statevalue += reward
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue, node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}()) reward, isterminalstate, node, Dict{String, MCTSNode}())
end end
end end
@@ -163,18 +169,18 @@ julia>
# Signature # Signature
""" """
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
isterminal::Function, max_depth::Int; n=3)::Number isterminal::Function, maxDepth::Int; n=3)::Number
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
for _ in 1:max_depth for depth in 1:maxDepth
if node.isterminal if node.isterminal
break break
else else
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
end
node = selectChildNode(node)
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
node = selectChildNode(node)
end
end end
return simTrajectoryReward return simTrajectoryReward
@@ -216,26 +222,14 @@ julia>
# Signature # Signature
""" """
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9) function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
while !isroot(node)
# Update the statistics of the current node based on the result of the playout # Update the statistics of the current node based on the result of the playout
node.visits += 1 node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
# Backpropagate the result to the parent node recursively node = node.parent
if !isroot(node)
simTrajectoryReward *= discountRewardCoeff
backpropagate(node.parent, simTrajectoryReward)
end end
end end
# function backpropagate(node::MCTSNode, reward::Float64)
# node.visits += 1
# # [] there is no total_reward in the paper, buy they use stateValue
# node.total_reward += reward
# if !isempty(node.children)
# best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
# backpropagate(node.children[best_child], -reward)
# end
# end
""" Get a new state """ Get a new state
@@ -256,18 +250,18 @@ end
# Example # Example
```jldoctest ```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}( julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."), :thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(), :storeinfo => Dict(),
:customerinfo => Dict() :customerinfo => Dict()
) )
julia> thoughtDict = Dict( julia> thoughtDict = Dict(
:Question=> "I want to buy a bottle of wine.", :question=> "I want to buy a bottle of wine.",
:Thought_1=> "The customer wants to buy a bottle of wine.", :thought_1=> "The customer wants to buy a bottle of wine.",
:Action_1=> Dict{Symbol, Any}( :action_1=> Dict{Symbol, Any}(
:name=>"Chatbox", :name=>"Chatbox",
:input=>"What occasion are you buying the wine for?", :input=>"What occasion are you buying the wine for?",
), ),
:Observation_1 => "" :observation_1 => ""
) )
``` ```
@@ -280,8 +274,8 @@ julia> thoughtDict = Dict(
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
actionname = thoughtDict[:Action][:name] actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:Action][:input] actioninput = thoughtDict[:action][:input]
# map action and input() to llm function # map action and input() to llm function
response = response =
@@ -289,23 +283,23 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
virtualWineCustomerChatbox(a, actioninput) # virtual customer virtualWineCustomerChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock" elseif actionname == "winestock"
winestock(a, actioninput) winestock(a, actioninput)
elseif actionname == "reccommendbox" elseif actionname == "recommendbox"
virtualWineCustomerReccommendbox(a, actioninput) virtualWineCustomerReccommendbox(a, actioninput)
else else
error("undefined LLM function. Requesting $actionname") error("undefined LLM function. Requesting $actionname")
end end
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
"Thought") "thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("Thought_$nextIndice") latestThoughtKey = Symbol("thought_$nextIndice")
latestActionKey = Symbol("Action_$nextIndice") latestActionKey = Symbol("action_$nextIndice")
# add Thought, action, observation to thoughtHistory # add Thought, action, observation to thoughtHistory
newstate = deepcopy(state) newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:Thought] newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:Action] newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
latestObservationKey = Symbol("Observation_$(nextIndice)") latestObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response newstate[:thoughtHistory][latestObservationKey] = response
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
@@ -332,7 +326,7 @@ julia> initialState = Dict{Symbol, Any}(
:storeinfo=> Dict{Symbol, Any}(), :storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}( :thoughtHistory=> OrderedDict{Symbol, Any}(
:Question=> "How are you?", :question=> "How are you?",
) )
) )
julia> statetype = typeof(initialState) julia> statetype = typeof(initialState)
@@ -341,6 +335,9 @@ julia> YiemAgent.isleaf(root)
true true
``` ```
# TODO
[] update docs
# Signature # Signature
""" """
isleaf(node::MCTSNode)::Bool = isempty(node.children) isleaf(node::MCTSNode)::Bool = isempty(node.children)
@@ -451,9 +448,9 @@ function runMCTS(
maxIterations::Integer, maxIterations::Integer,
w::Float64) where {T1<:agent} w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) root = MCTSNode("root", initialState, 0, "N/A", 0, 0, false, nothing, Dict{String, MCTSNode}())
for _ in 1:maxIterations for nth in 1:maxIterations
node = root node = root
while !isleaf(node) while !isleaf(node)
node = UCTselect(node, w) node = UCTselect(node, w)
@@ -462,6 +459,7 @@ function runMCTS(
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
leaf_node = selectChildNode(node) leaf_node = selectChildNode(node)
# BUG i didn't assign parent node for this leaf node yet
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator, simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
isterminal, maxDepth, n=n) isterminal, maxDepth, n=n)
backpropagate(leaf_node, simTrajectoryReward) backpropagate(leaf_node, simTrajectoryReward)

View File

@@ -68,7 +68,8 @@ response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a b
"It will be Thai dishes."
"I like medium-bodied with low tannin."

96
test/test_1.jl Normal file
View File

@@ -0,0 +1,96 @@
using Revise
using YiemAgent, GeneralUtils, JSON3, DataStructures
msgMeta = Dict(:requestResponse => nothing,
:msgPurpose => nothing,
:receiverId => nothing,
:getPost => nothing,
:msgId => "4c7111e0-c30e-44c3-8f85-1c8b3f03a8be",
:acknowledgestatus => nothing,
:replyToMsgId => "dummyid",
:msgFormatVersion => nothing,
:mqttServerInfo => Dict(:port => 1883, :broker => "mqtt.yiem.cc"),
:sendTopic => "/testingSessionID",
:receiverName => "wineassistant",
:replyTopic => nothing,
:senderName => "test_1",
:senderSelfnote => nothing,
:senderId => nothing,
:timeStamp => "2024-05-04T08:06:23.561"
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "We are holding a wedding party",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "It will be Thai dishes.",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "50 bucks.",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I like full-bodied Red wine with low tannin.",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "What do you have?",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "<<ok>>",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)