This commit is contained in:
narawat lamaiin
2024-05-28 23:48:50 +07:00
parent fcf8d855b8
commit 3f38fdbb70
7 changed files with 202 additions and 452 deletions

View File

@@ -266,15 +266,15 @@ version = "1.0.3"
[[deps.LoweredCodeUtils]] [[deps.LoweredCodeUtils]]
deps = ["JuliaInterpreter"] deps = ["JuliaInterpreter"]
git-tree-sha1 = "31e27f0b0bf0df3e3e951bfcc43fe8c730a219f6" git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "2.4.5" version = "2.4.6"
[[deps.MQTTClient]] [[deps.MQTTClient]]
deps = ["Distributed", "Random", "Sockets"] deps = ["Distributed", "Random", "Sockets"]
git-tree-sha1 = "7d6a1042b8c330d20e4dfbd941f510f92b457624" git-tree-sha1 = "c58ba9d6ae121f58494fa1e5164213f5b4e3e2c7"
uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
version = "0.2.1" version = "0.3.0"
weakdeps = ["PrecompileTools"] weakdeps = ["PrecompileTools"]
[deps.MQTTClient.extensions] [deps.MQTTClient.extensions]
@@ -408,9 +408,9 @@ deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[deps.PtrArrays]] [[deps.PtrArrays]]
git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913" git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.1.0" version = "1.2.0"
[[deps.PythonCall]] [[deps.PythonCall]]
deps = ["CondaPkg", "Dates", "Libdl", "MacroTools", "Markdown", "Pkg", "REPL", "Requires", "Serialization", "Tables", "UnsafePointers"] deps = ["CondaPkg", "Dates", "Libdl", "MacroTools", "Markdown", "Pkg", "REPL", "Requires", "Serialization", "Tables", "UnsafePointers"]
@@ -456,10 +456,10 @@ uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.7.1" version = "0.7.1"
[[deps.Rmath_jll]] [[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.4.0+0" version = "0.4.2+0"
[[deps.SHA]] [[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

View File

@@ -223,12 +223,18 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
} }
Here are some examples: Here are some examples:
sommelier: "What's your budget?
you:
{ {
"text": "My budget is 30 USD.", "text": "My budget is 30 USD.",
"select": null, "select": null,
"reward": 0, "reward": 0,
"isterminal": false "isterminal": false
} }
sommelier: "The first option is Zena Crown and the second one is Buano Red."
you:
{ {
"text": "I like the 2nd option.", "text": "I like the 2nd option.",
"select": 2, "select": 2,
@@ -307,12 +313,12 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseDict = copy(JSON3.read(responseJsonStr)) responseDict = copy(JSON3.read(responseJsonStr))
text = responseDict[:text] text::AbstractString = responseDict[:text]
select = responseDict[:select] == "null" ? nothing : responseDict[:select] select::Union{Nothing, Number} = responseDict[:select] == "null" ? nothing : responseDict[:select]
reward = responseDict[:reward] reward::Number = responseDict[:reward]
isterminal = responseDict[:isterminal] isterminal::Bool = responseDict[:isterminal]
if text != "" && select != "" && reward != "" && isterminal != "" if text != ""
# pass test # pass test
else else
error("virtual customer not answer correctly") error("virtual customer not answer correctly")
@@ -332,58 +338,6 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
error("virtualWineUserChatbox failed to get a response") error("virtualWineUserChatbox failed to get a response")
end end
# function virtualWineUserChatbox(a::T1, input::T2
# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# # put in model format
# virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
# llminfo = virtualWineCustomer[:llminfo]
# prompt =
# if llminfo[:name] == "llama3instruct"
# formatLLMtext_llama3instruct("assistant", input)
# else
# error("llm model name is not defined yet $(@__LINE__)")
# end
# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
# msgMeta = GeneralUtils.generate_msgMeta(
# virtualWineCustomer[:mqtttopic],
# senderName= "virtualWineUserChatbox",
# senderId= a.id,
# receiverName= "virtualWineCustomer",
# mqttBroker= a.config[:mqttServerInfo][:broker],
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
# msgId = "dummyid" #CHANGE remove after testing finished
# )
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :text=> prompt,
# )
# )
# attempt = 0
# for attempt in 1:5
# try
# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
# response = result[:response]
# return (response[:text], response[:select], response[:reward], response[:isterminal])
# catch e
# io = IOBuffer()
# showerror(io, e)
# errorMsg = String(take!(io))
# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
# println("")
# @warn "Error occurred: $errorMsg\n$st"
# println("")
# end
# end
# error("virtualWineUserChatbox failed to get a response")
# end
""" Search wine in stock. """ Search wine in stock.
# Arguments # Arguments
@@ -411,6 +365,150 @@ julia> result = winestock(agent, input)
function winestock(a::T1, input::T2 function winestock(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
systemmsg =
"""
As an attentive sommelier, your mission is to determine the user's preferred levels of sweetness, intensity, tannin, and acidity for a wine based on their input.
You'll achieve this by referring to the provided conversion table.
Conversion Table:
Intensity level:
Level 1: May correspond to "light-bodied" or a similar description.
Level 2: May correspond to "med-light" or a similar description.
Level 3: May correspond to "medium" or a similar description.
Level 4: May correspond to "med-full" or a similar description.
Level 5: May correspond to "full" or a similar description.
Sweetness level:
Level 1: May correspond to "dry", "no-sweet" or a similar description.
Level 2: May correspond to "off-dry", "less-sweet" or a similar description.
Level 3: May correspond to "semi-sweet" or a similar description.
Level 4: May correspond to "sweet" or a similar description.
Level 5: May correspond to "very sweet" or a similar description.
Tannin level:
Level 1: May correspond to "low tannin" or a similar description.
Level 2: May correspond to "semi-low tannin" or a similar description.
Level 3: May correspond to "medium tannin" or a similar description.
Level 4: May correspond to "semi-high tannin" or a similar description.
Level 5: May correspond to "high tannin" or a similar description.
Acidity level:
Level 1: May correspond to "low acidity" or a similar description.
Level 2: May correspond to "semi-low acidity" or a similar description.
Level 3: May correspond to "medium acidity" or a similar description.
Level 4: May correspond to "semi-high acidity" or a similar description.
Level 5: May correspond to "high acidity" or a similar description.
You should only respond in JSON format as describe below:
{
"sweetness": "sweetness level",
"acidity": "acidity level",
"tannin": "tannin level",
"intensity": "intensity level"
}
Here are some examples:
user: red wines, price < 50, body=full-bodied, tannins=1, off dry, acidity=medium, intensity=intense, Thai dishes
assistant:
{
"wine_attributes":
{
"sweetness": 2,
"acidity": 3,
"tannin": 1,
"intensity": 5
}
}
Let's begin!
"""
usermsg =
"""
$input
"""
chathistory =
[
Dict(:name=> "system", :text=> systemmsg),
Dict(:name=> "user", :text=> usermsg)
]
# put in model format
prompt = formatLLMtext(chathistory, "llama3instruct")
prompt *=
"""
<|start_header_id|>assistant<|end_header_id|>
{
"""
pprint(prompt)
externalService = a.config[:externalservice][:text2textinstruct]
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "virtualWineUserChatbox",
senderId= a.id,
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
)
)
attempt = 0
for attempt in 1:5
try
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
_responseJsonStr = response[:response][:text]
expectedJsonExample =
"""
Here is an expected JSON format:
{
"wine_attributes":
{
"...": "...",
"...": "...",
}
}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseDict = copy(JSON3.read(responseJsonStr))
return (text, select, reward, isterminal)
catch e
io = IOBuffer()
showerror(io, e)
errorMsg = String(take!(io))
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
println("")
@warn "Error occurred: $errorMsg\n$st"
println("")
end
end
error("virtualWineUserChatbox failed to get a response")
winesStr = winesStr =
""" """
1: El Enemigo Cabernet Franc 2019 1: El Enemigo Cabernet Franc 2019
@@ -425,6 +523,23 @@ function winestock(a::T1, input::T2
""" """
return result, nothing, 0, false return result, nothing, 0, false
end end
# function winestock(a::T1, input::T2
# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# winesStr =
# """
# 1: El Enemigo Cabernet Franc 2019
# 2: Tantara Chardonnay 2017
# """
# result =
# """
# I found the following wines in our stock:
# {
# $winesStr
# }
# """
# return result, nothing, 0, false
# end
""" Attemp to correct LLM response's incorrect JSON response. """ Attemp to correct LLM response's incorrect JSON response.
@@ -446,13 +561,14 @@ julia>
# Signature # Signature
""" """
function jsoncorrection(a::T1, input::T2, function jsoncorrection(a::T1, input::T2, correctJsonExample::T3;
correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} maxattempt::Integer=3
) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
incorrectjson = deepcopy(input) incorrectjson = deepcopy(input)
correctjson = nothing correctjson = nothing
for attempt in 1:5 for attempt in 1:maxattempt
try try
d = copy(JSON3.read(incorrectjson)) d = copy(JSON3.read(incorrectjson))
correctjson = incorrectjson correctjson = incorrectjson

View File

@@ -1,138 +0,0 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
"""
module MCTS
# export
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
TODO\n
[] update docstring
"""
struct MCTSNode{T}
state::T
visits::Int
total_reward::Float64
children::Dict{T, MCTSNode}
end
"""
TODO\n
[] update docstring
"""
function select(node::MCTSNode, c::Float64)
max_uct = -Inf
selected_node = nothing
for (child_state, child_node) in node.children
uct_value = child_node.total_reward / child_node.visits +
c * sqrt(log(node.visits) / child_node.visits)
if uct_value > max_uct
max_uct = uct_value
selected_node = child_node
end
end
return selected_node
end
"""
TODO\n
[] update docstring
"""
function expand(node::MCTSNode, state::T, actions::Vector{T})
for action in actions
new_state = transition(node.state, action) # Implement your transition function
if new_state keys(node.children)
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
TODO\n
[] update docstring
"""
function simulate(state::T, max_depth::Int)
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
state, reward = transition(state, action) # Implement your transition function
total_reward += reward
end
return total_reward
end
"""
TODO\n
[] update docstring
"""
function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1
node.total_reward += reward
if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward)
end
end
"""
TODO\n
[] update docstring
[] implement transition()
"""
function transition(state, action)
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
"""
TODO\n
[] update docstring
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
root = MCTSNode(initial_state, 0, 0.0, Dict())
for _ in 1:max_iterations
node = root
while !is_leaf(node)
node = select(node, w)
end
expand(node, node.state, actions)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, max_depth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end
# Define your transition function and action selection function here
# Example usage
initial_state = 0
actions = [-1, 0, 1]
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
println("Best action to take: ", best_action)
end

View File

@@ -1,6 +1,4 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence """ https://www.harrycodes.com/blog/monte-carlo-tree-search
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
""" """
module mcts module mcts

View File

@@ -1,226 +0,0 @@
module type

export agent, sommelier

using Dates, UUIDs, DataStructures, JSON3
using GeneralUtils

# ---------------------------------------------- 100 --------------------------------------------- #
# Root abstract type for every agent in the system.
abstract type agent end

""" A sommelier agent.

# Arguments
- `mqttClient::Client`
    MQTTClient's client
- `msgMeta::Dict{Symbol, Any}`
    A dict contain info about a message.
- `config::Dict{Symbol, Any}`
    Config info for an agent. Contain mqtt topic for internal use and other info.
# Keyword Arguments
- `name::String`
    Agent's name
- `id::String`
    Agent's ID
- `tools::Dict{Symbol, Any}`
    Agent's tools
- `maxHistoryMsg::Integer`
    max history message
# Return
- `nothing`
# Example
```jldoctest
julia> using YiemAgent, MQTTClient, GeneralUtils
julia> msgMeta = GeneralUtils.generate_msgMeta(
        "N/A",
        replyTopic = "/testtopic/prompt"
    )
julia> tools= Dict(
        :chatbox=>Dict(
            :name => "chatbox",
            :description => "Useful only for when you need to ask the user for more info or context. Do not ask the user their own question.",
            :input => "Input should be a text.",
            :output => "" ,
            :func => nothing,
        ),
    )
julia> agentConfig = Dict(
        :receiveprompt=>Dict(
            :mqtttopic=> "/testtopic/prompt", # topic to receive prompt i.e. frontend send msg to this topic
        ),
        :receiveinternal=>Dict(
            :mqtttopic=> "/testtopic/internal", # receive topic for model's internal
        ),
        :text2text=>Dict(
            :mqtttopic=> "/text2text/receive",
        ),
    )
julia> client, connection = MakeConnection("test.mosquitto.org", 1883)
julia> agent = YiemAgent.sommelier(
        client,
        msgMeta,
        agentConfig,
        name= "assistant",
        id= "555", # agent instance id
        tools=tools,
    )
```
# TODO
- [] update docstring
- [x] implement the function
# Signature
"""
@kwdef mutable struct sommelier <: agent
    name::String # agent name
    id::String # agent id
    config::Dict # agent config
    tools::Dict
    thinkinglimit::Integer # thinking round limit
    thinkingcount::Integer # used to count attempted round of a task
    """ Memory
    Ref: Chat prompt format https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/discussions/3
    NO "system" message in chathistory because I want to add it at the inference time
    chathistory= [
        Dict(:name=>"user", :text=> "Wassup!", :timestamp=> Dates.now()),
        Dict(:name=>"assistant", :text=> "Hi I'm your assistant.", :timestamp=> Dates.now()),
    ]
    """
    chathistory::Vector{Dict{Symbol, Any}} = Vector{Dict{Symbol, Any}}()
    maxHistoryMsg::Integer # 21st and earlier messages will get summarized
    keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}(
        :customerinfo => Dict{Symbol, Any}(),
        :storeinfo => Dict{Symbol, Any}(),
    )
    mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
    # 1-historyPoint is in Dict{Symbol, Any} and compose of:
    # state, statevalue, thought, action, observation
    plan::Dict{Symbol, Any} = Dict{Symbol, Any}(
        # store 3 to 5 best plan AI frequently used to avoid having to search MCTS all the time
        # each plan is in [historyPoint_1, historyPoint_2, ...] format
        :existingplan => Vector(),
        :activeplan => Dict{Symbol, Any}(), # current using plan
        :currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
    )
end

"""
    sommelier(config::Dict = Dict(...); name, id, tools, maxHistoryMsg, thinkinglimit, thinkingcount)

Convenience constructor that builds a `sommelier` from a `config` dict plus
keyword options, filling the memory fields (`chathistory`, `keywordinfo`,
`mctstree`, `plan`) with empty defaults.

Bug fixed: this zero-positional-argument method replaces the `@kwdef`
keyword constructor, so the original body's call
`sommelier(name=..., config=..., ...)` dispatched back to this method —
which has no `config` keyword — and every construction raised a
`MethodError`. The struct is now built with the positional inner
constructor instead.
"""
function sommelier(
    config::Dict = Dict(
        :mqttServerInfo=> Dict(
            :broker=> nothing,
            :port=> nothing,
        ),
        :receivemsg=> Dict(
            :prompt=> nothing, # topic to receive prompt i.e. frontend send msg to this topic
            :internal=> nothing,
        ),
        :thirdPartyService=> Dict(
            :text2textinstruct=> nothing,
            :text2textchat=> nothing,
        ),
    )
    ;
    name::String= "Assistant",
    id::String= string(uuid4()),
    tools::Dict= Dict(
        :chatbox=> Dict(
            :name => "chatbox",
            :description => "Useful for when you need to communicate with the user.",
            :input => "Input should be a conversation to the user.",
            :output => "" ,
            :func => nothing,
        ),
    ),
    maxHistoryMsg::Integer= 20,
    thinkinglimit::Integer= 5,
    thinkingcount::Integer= 0,
)
    #[NEXTVERSION] publish to a.config[:configtopic] to get a config.
    #[NEXTVERSION] get a config message in a.mqttMsg_internal
    #[NEXTVERSION] set agent according to config
    # Positional inner constructor: arguments must follow field-declaration
    # order (name, id, config, tools, thinkinglimit, thinkingcount,
    # chathistory, maxHistoryMsg, keywordinfo, mctstree, plan).
    newAgent = sommelier(
        name,
        id,
        config,
        tools,
        thinkinglimit,
        thinkingcount,
        Vector{Dict{Symbol, Any}}(),            # chathistory
        maxHistoryMsg,
        Dict{Symbol, Any}(                      # keywordinfo
            :customerinfo => Dict{Symbol, Any}(),
            :storeinfo => Dict{Symbol, Any}(),
        ),
        Dict{Symbol, Any}(),                    # mctstree
        Dict{Symbol, Any}(                      # plan
            :existingplan => Vector(),
            :activeplan => Dict{Symbol, Any}(),
            :currenttrajectory=> Dict{Symbol, Any}(),
        ),
    )
    return newAgent
end

end # module type

View File

@@ -59,26 +59,26 @@ tools=Dict( # update input format
# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) ) # response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine", # response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
:select=> nothing, # :select=> nothing,
:reward=> 0, # :reward=> 0,
:isterminal=> false, # :isterminal=> false,
) ) # ) )
println("---> YiemAgent: ", response) # println("---> YiemAgent: ", response)
response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.", # #BUG mcts do not start at current chat history
:select=> nothing, # response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.",
:reward=> 0, # :select=> nothing,
:isterminal=> false, # :reward=> 0,
) ) # :isterminal=> false,
println("---> YiemAgent: ", response) # ) )
# println("---> YiemAgent: ", response)
dummyinput = "price < 50, full-bodied red wine with sweetness level 2, low tannin level and medium acidity level, Thai dishes"
"It will be Thai dishes." response = YiemAgent.winestock(a, dummyinput)
"I like medium-bodied with low tannin."