This commit is contained in:
narawat lamaiin
2024-05-28 23:48:50 +07:00
parent fcf8d855b8
commit 3f38fdbb70
7 changed files with 202 additions and 452 deletions

View File

@@ -194,9 +194,9 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
:kwargs=> Dict(
:max_tokens=> 512,
:stop=> ["<|eot_id|>"],
)
)
)
)
@show outgoingMsg
for attempt in 1:5

View File

@@ -223,12 +223,18 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
}
Here are some examples:
sommelier: "What's your budget?
you:
{
"text": "My budget is 30 USD.",
"select": null,
"reward": 0,
"isterminal": false
}
sommelier: "The first option is Zena Crown and the second one is Buano Red."
you:
{
"text": "I like the 2nd option.",
"select": 2,
@@ -307,12 +313,12 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseDict = copy(JSON3.read(responseJsonStr))
text = responseDict[:text]
select = responseDict[:select] == "null" ? nothing : responseDict[:select]
reward = responseDict[:reward]
isterminal = responseDict[:isterminal]
text::AbstractString = responseDict[:text]
select::Union{Nothing, Number} = responseDict[:select] == "null" ? nothing : responseDict[:select]
reward::Number = responseDict[:reward]
isterminal::Bool = responseDict[:isterminal]
if text != "" && select != "" && reward != "" && isterminal != ""
if text != ""
# pass test
else
error("virtual customer not answer correctly")
@@ -332,58 +338,6 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
error("virtualWineUserChatbox failed to get a response")
end
# function virtualWineUserChatbox(a::T1, input::T2
# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# # put in model format
# virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
# llminfo = virtualWineCustomer[:llminfo]
# prompt =
# if llminfo[:name] == "llama3instruct"
# formatLLMtext_llama3instruct("assistant", input)
# else
# error("llm model name is not defied yet $(@__LINE__)")
# end
# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
# msgMeta = GeneralUtils.generate_msgMeta(
# virtualWineCustomer[:mqtttopic],
# senderName= "virtualWineUserChatbox",
# senderId= a.id,
# receiverName= "virtualWineCustomer",
# mqttBroker= a.config[:mqttServerInfo][:broker],
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
# msgId = "dummyid" #CHANGE remove after testing finished
# )
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :text=> prompt,
# )
# )
# attempt = 0
# for attempt in 1:5
# try
# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
# response = result[:response]
# return (response[:text], response[:select], response[:reward], response[:isterminal])
# catch e
# io = IOBuffer()
# showerror(io, e)
# errorMsg = String(take!(io))
# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
# println("")
# @warn "Error occurred: $errorMsg\n$st"
# println("")
# end
# end
# error("virtualWineUserChatbox failed to get a response")
# end
""" Search wine in stock.
# Arguments
@@ -411,6 +365,150 @@ julia> result = winestock(agent, input)
function winestock(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
systemmsg =
"""
As an attentive sommelier, your mission is to determine the user's preferred levels of sweetness, intensity, tannin, and acidity for a wine based on their input.
You'll achieve this by referring to the provided conversion table.
Conversion Table:
Intensity level:
Level 1: May correspond to "light-bodied" or a similar description.
Level 2: May correspond to "med-light" or a similar description.
Level 3: May correspond to "medium" or a similar description.
Level 4: May correspond to "med-full" or a similar description.
Level 5: May correspond to "full" or a similar description.
Sweetness level:
Level 1: May correspond to "dry", "no-sweet" or a similar description.
Level 2: May correspond to "off-dry", "less-sweet" or a similar description.
Level 3: May correspond to "semi-sweet" or a similar description.
Level 4: May correspond to "sweet" or a similar description.
Level 5: May correspond to "very sweet" or a similar description.
Tannin level:
Level 1: May correspond to "low tannin" or a similar description.
Level 2: May correspond to "semi-low tannin" or a similar description.
Level 3: May correspond to "medium tannin" or a similar description.
Level 4: May correspond to "semi-high tannin" or a similar description.
Level 5: May correspond to "high tannin" or a similar description.
Acidity level:
Level 1: May correspond to "low acidity" or a similar description.
Level 2: May correspond to "semi-low acidity" or a similar description.
Level 3: May correspond to "medium acidity" or a similar description.
Level 4: May correspond to "semi-high acidity" or a similar description.
Level 5: May correspond to "high acidity" or a similar description.
You should only respond in JSON format as describe below:
{
"sweetness": "sweetness level",
"acidity": "acidity level",
"tannin": "tannin level",
"intensity": "intensity level"
}
Here are some examples:
user: red wines, price < 50, body=full-bodied, tannins=1, off dry, acidity=medium, intensity=intense, Thai dishes
assistant:
{
"wine_attributes":
{
"sweetness": 2,
"acidity": 3,
"tannin": 1,
"intensity": 5
}
}
Let's begin!
"""
usermsg =
"""
$input
"""
chathistory =
[
Dict(:name=> "system", :text=> systemmsg),
Dict(:name=> "user", :text=> usermsg)
]
# put in model format
prompt = formatLLMtext(chathistory, "llama3instruct")
prompt *=
"""
<|start_header_id|>assistant<|end_header_id|>
{
"""
pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "virtualWineUserChatbox",
senderId= a.id,
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
)
)
attempt = 0
for attempt in 1:5
try
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
_responseJsonStr = response[:response][:text]
expectedJsonExample =
"""
Here is an expected JSON format:
{
"wine_attributes":
{
"...": "...",
"...": "...",
}
}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseDict = copy(JSON3.read(responseJsonStr))
return (text, select, reward, isterminal)
catch e
io = IOBuffer()
showerror(io, e)
errorMsg = String(take!(io))
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
println("")
@warn "Error occurred: $errorMsg\n$st"
println("")
end
end
error("virtualWineUserChatbox failed to get a response")
winesStr =
"""
1: El Enemigo Cabernet Franc 2019
@@ -425,6 +523,23 @@ function winestock(a::T1, input::T2
"""
return result, nothing, 0, false
end
# function winestock(a::T1, input::T2
# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# winesStr =
# """
# 1: El Enemigo Cabernet Franc 2019
# 2: Tantara Chardonnay 2017
# """
# result =
# """
# I found the following wines in our stock:
# {
# $winesStr
# }
# """
# return result, nothing, 0, false
# end
""" Attempt to correct incorrect JSON in an LLM response.
@@ -446,13 +561,14 @@ julia>
# Signature
"""
function jsoncorrection(a::T1, input::T2,
correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
function jsoncorrection(a::T1, input::T2, correctJsonExample::T3;
maxattempt::Integer=3
) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
incorrectjson = deepcopy(input)
correctjson = nothing
for attempt in 1:5
for attempt in 1:maxattempt
try
d = copy(JSON3.read(incorrectjson))
correctjson = incorrectjson

View File

@@ -1,138 +0,0 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
"""
module MCTS
# export
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
    MCTSNode{T}

One node of the Monte Carlo search tree.

Declared `mutable` because `backpropagate` updates `visits` and `total_reward`
in place — the original immutable `struct` made those assignments throw.

# Fields
- `state::T`: state this node represents.
- `visits::Int`: number of times the node has been visited.
- `total_reward::Float64`: cumulative reward collected through this node.
- `children::Dict{T, MCTSNode}`: children keyed by their state.
"""
mutable struct MCTSNode{T}
    state::T
    visits::Int
    total_reward::Float64
    children::Dict{T, MCTSNode}
end
"""
    select(node::MCTSNode, c::Float64)

Return the child of `node` with the highest UCT score
`total_reward / visits + c * sqrt(log(node.visits) / visits)`,
or `nothing` when `node` has no children.

An unvisited child scores `Inf` so it is explored first; the original's
`0 / 0 == NaN` never compared greater than anything, so unvisited children
could never be selected.
"""
function select(node::MCTSNode, c::Float64)
    max_uct = -Inf
    selected_node = nothing
    for (child_state, child_node) in node.children
        uct_value =
            if child_node.visits == 0
                Inf # force exploration of unvisited children (0/0 would be NaN)
            else
                child_node.total_reward / child_node.visits +
                c * sqrt(log(node.visits) / child_node.visits)
            end
        if uct_value > max_uct
            max_uct = uct_value
            selected_node = child_node
        end
    end
    return selected_node
end
"""
    is_leaf(node::MCTSNode) -> Bool

Return `true` when `node` has no expanded children. Called by `run_mcts`
below but never defined in the original file.
"""
is_leaf(node::MCTSNode) = isempty(node.children)
"""
    expand(node::MCTSNode{T}, state::T, actions::Vector{T}) where {T}

Add a child to `node` for every state reachable via `actions` that is not
already a child. Relies on `transition` (still a stub below).

NOTE(review): in the original, `T` was unconstrained (no `where` clause,
causing an `UndefVarError` when the method was defined) and the membership
test was missing its `∉` operator — both fixed here. The `state` argument is
unused; kept for call-site compatibility.
"""
function expand(node::MCTSNode{T}, state::T, actions::Vector{T}) where {T}
    for action in actions
        new_state = transition(node.state, action) # Implement your transition function
        if new_state ∉ keys(node.children)
            node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
        end
    end
end
"""
    simulate(state, max_depth::Int)

Roll out a trajectory of at most `max_depth` steps from `state`, summing the
rewards produced by `transition`. Requires `select_action` and a `transition`
returning `(state, reward)` — neither is implemented in this file yet.
(The original annotated `state::T` with no `where` clause, which throws at
method-definition time; the annotation served no purpose and is dropped.)
"""
function simulate(state, max_depth::Int)
    total_reward = 0.0
    for _ in 1:max_depth
        action = select_action(state) # Implement your action selection function
        state, reward = transition(state, action) # Implement your transition function
        total_reward += reward
    end
    return total_reward
end
"""
    backpropagate(node::MCTSNode, reward::Float64)

Record one visit and `reward` on `node`, then recurse into the child with the
best mean reward, negating the reward (adversarial/two-player convention
carried over from the original).

The original indexed `node.children` (a `Dict` keyed by state) with the
integer returned by `argmax` over a score vector — a guaranteed `KeyError`;
fixed by collecting the children first.
"""
function backpropagate(node::MCTSNode, reward::Float64)
    node.visits += 1
    node.total_reward += reward
    if !isempty(node.children)
        children = collect(values(node.children))
        # max(visits, 1) guards against a 0/0 NaN for unvisited children
        scores = [child.total_reward / max(child.visits, 1) for child in children]
        backpropagate(children[argmax(scores)], -reward)
    end
end
"""
    transition(state, action)

Environment model: must return `(new_state, reward)`.
TODO: still unimplemented — every function above depends on it.
"""
function transition(state, action)
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions                                       #
# ------------------------------------------------------------------------------------------------ #
"""
    run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)

Run `max_iterations` rounds of select / expand / simulate / backpropagate
starting from `initial_state` and return the root's child state with the best
mean reward. `w` is the UCT exploration constant.

NOTE(review): `leaf_node = node.children[node.state]` assumes `transition`
can map a state back onto itself; confirm once `transition` is implemented.
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
    root = MCTSNode(initial_state, 0, 0.0, Dict())
    for _ in 1:max_iterations
        node = root
        while !is_leaf(node)
            node = select(node, w)
        end
        expand(node, node.state, actions)
        leaf_node = node.children[node.state]
        reward = simulate(leaf_node.state, max_depth)
        backpropagate(leaf_node, reward)
    end
    # The original applied argmax to a score vector and returned the *index*;
    # return the best child's state instead.
    children = collect(pairs(root.children))
    scores = [child.total_reward / max(child.visits, 1) for (_, child) in children]
    best_child_state = first(children[argmax(scores)])
    return best_child_state
end
# Define your transition function and action selection function here
# Example usage — left commented out: running it at module-load time (as the
# original did) throws, because `transition`/`select_action` are unimplemented.
# initial_state = 0
# actions = [-1, 0, 1]
# best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
# println("Best action to take: ", best_action)
end

View File

@@ -1,6 +1,4 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
""" https://www.harrycodes.com/blog/monte-carlo-tree-search
"""
module mcts

View File

@@ -1,226 +0,0 @@
module type
export agent, sommelier
using Dates, UUIDs, DataStructures, JSON3
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
Root abstract type for all agents in this module.
"""
abstract type agent end
""" A sommelier agent.

Holds the agent's identity, configuration, tools, chat memory and planning
state. Construct it either with keyword arguments directly (via `@kwdef`) or
through the outer convenience constructor `sommelier(config; ...)` below.

# Fields
- `name::String`: agent's display name.
- `id::String`: agent's unique id.
- `config::Dict`: agent config (mqtt server info, topics, external services, ...).
- `tools::Dict`: tools the agent can invoke, keyed by tool symbol.
- `thinkinglimit::Integer`: thinking-round limit per task.
- `thinkingcount::Integer`: attempted rounds of the current task.
- `chathistory::Vector{Dict{Symbol, Any}}`: conversation memory (see note on the field).
- `maxHistoryMsg::Integer`: messages beyond this count get summarized.
- `keywordinfo::Dict{Symbol, Any}`: extracted customer/store info.
- `mctstree::Dict{Symbol, Any}`: planning search tree.
- `plan::Dict{Symbol, Any}`: stored plans and the current trajectory.

# Example
```jldoctest
julia> using YiemAgent

julia> agentConfig = Dict(
           :mqttServerInfo=> Dict(:broker=> nothing, :port=> nothing),
           :receivemsg=> Dict(:prompt=> nothing, :internal=> nothing),
           :thirdPartyService=> Dict(:text2textinstruct=> nothing, :text2textchat=> nothing),
       );

julia> agent = YiemAgent.sommelier(
           agentConfig;
           name= "assistant",
           id= "555", # agent instance id
       );
```
NOTE(review): the original example called `YiemAgent.bsommelier(client, msgMeta, ...)`,
a signature that does not exist in this module — rewritten to match the actual
constructor; verify against the package's public API.

# TODO
- [] update docstring
- [x] implement the function
# Signature
"""
@kwdef mutable struct sommelier <: agent
    name::String # agent name
    id::String # agent id
    config::Dict # agent config
    tools::Dict
    thinkinglimit::Integer # thinking round limit
    thinkingcount::Integer # used to count attempted round of a task
    """ Memory
    Ref: Chat prompt format https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/discussions/3
    NO "system" message in chathistory because I want to add it at the inference time
    chathistory= [
    Dict(:name=>"user", :text=> "Wassup!", :timestamp=> Dates.now()),
    Dict(:name=>"assistant", :text=> "Hi I'm your assistant.", :timestamp=> Dates.now()),
    ]
    """
    chathistory::Vector{Dict{Symbol, Any}} = Vector{Dict{Symbol, Any}}()
    maxHistoryMsg::Integer # 21th and earlier messages will get summarized
    keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}(
        :customerinfo => Dict{Symbol, Any}(),
        :storeinfo => Dict{Symbol, Any}(),
    )
    mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
    # 1-historyPoint is in Dict{Symbol, Any} and compose of:
    # state, statevalue, thought, action, observation
    plan::Dict{Symbol, Any} = Dict{Symbol, Any}(
        # store 3 to 5 best plan AI frequently used to avoid having to search MCTS all the time
        # each plan is in [historyPoint_1, historyPoint_2, ...] format
        :existingplan => Vector(),
        :activeplan => Dict{Symbol, Any}(), # current using plan
        :currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
    )
end
"""
    sommelier(config::Dict = <placeholder config>;
              name, id, tools, maxHistoryMsg, thinkinglimit, thinkingcount)

Outer convenience constructor: build a `sommelier` from a `config` dict plus
keyword options, filling the memory fields (`chathistory`, `keywordinfo`,
`mctstree`, `plan`) with the same defaults declared on the struct.

NOTE(review): the original body forwarded back into `sommelier(...)` with
keyword arguments. Because this method's zero-positional form shadows the
`@kwdef` keyword constructor, that call could only reach a method with no
`config` keyword and always threw. Fixed by calling the positional field
constructor directly (field order matches the struct declaration).
"""
function sommelier(
    config::Dict = Dict(
        :mqttServerInfo=> Dict(
            :broker=> nothing,
            :port=> nothing,
        ),
        :receivemsg=> Dict(
            :prompt=> nothing, # topic to receive prompt i.e. frontend send msg to this topic
            :internal=> nothing,
        ),
        :thirdPartyService=> Dict(
            :text2textinstruct=> nothing,
            :text2textchat=> nothing,
        ),
    )
    ;
    name::String= "Assistant",
    id::String= string(uuid4()),
    tools::Dict= Dict(
        :chatbox=> Dict(
            :name => "chatbox",
            :description => "Useful for when you need to communicate with the user.",
            :input => "Input should be a conversation to the user.",
            :output => "" ,
            :func => nothing,
        ),
    ),
    maxHistoryMsg::Integer= 20,
    thinkinglimit::Integer= 5,
    thinkingcount::Integer= 0,
)
    #[NEXTVERSION] publish to a.config[:configtopic] to get a config.
    #[NEXTVERSION] get a config message in a.mqttMsg_internal
    #[NEXTVERSION] set agent according to config
    newAgent = sommelier(
        name,
        id,
        config,
        tools,
        thinkinglimit,
        thinkingcount,
        Vector{Dict{Symbol, Any}}(),       # chathistory: fresh, empty
        maxHistoryMsg,
        Dict{Symbol, Any}(                 # keywordinfo: same defaults as the struct
            :customerinfo => Dict{Symbol, Any}(),
            :storeinfo => Dict{Symbol, Any}(),
        ),
        Dict{Symbol, Any}(),               # mctstree
        Dict{Symbol, Any}(                 # plan: same defaults as the struct
            :existingplan => Vector(),
            :activeplan => Dict{Symbol, Any}(),
            :currenttrajectory=> Dict{Symbol, Any}(),
        ),
    )
    return newAgent
end
end # module type