update
This commit is contained in:
@@ -266,15 +266,15 @@ version = "1.0.3"
|
|||||||
|
|
||||||
[[deps.LoweredCodeUtils]]
|
[[deps.LoweredCodeUtils]]
|
||||||
deps = ["JuliaInterpreter"]
|
deps = ["JuliaInterpreter"]
|
||||||
git-tree-sha1 = "31e27f0b0bf0df3e3e951bfcc43fe8c730a219f6"
|
git-tree-sha1 = "c6a36b22d2cca0e1a903f00f600991f97bf5f426"
|
||||||
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
|
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
|
||||||
version = "2.4.5"
|
version = "2.4.6"
|
||||||
|
|
||||||
[[deps.MQTTClient]]
|
[[deps.MQTTClient]]
|
||||||
deps = ["Distributed", "Random", "Sockets"]
|
deps = ["Distributed", "Random", "Sockets"]
|
||||||
git-tree-sha1 = "7d6a1042b8c330d20e4dfbd941f510f92b457624"
|
git-tree-sha1 = "c58ba9d6ae121f58494fa1e5164213f5b4e3e2c7"
|
||||||
uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
|
uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
|
||||||
version = "0.2.1"
|
version = "0.3.0"
|
||||||
weakdeps = ["PrecompileTools"]
|
weakdeps = ["PrecompileTools"]
|
||||||
|
|
||||||
[deps.MQTTClient.extensions]
|
[deps.MQTTClient.extensions]
|
||||||
@@ -408,9 +408,9 @@ deps = ["Unicode"]
|
|||||||
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||||
|
|
||||||
[[deps.PtrArrays]]
|
[[deps.PtrArrays]]
|
||||||
git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913"
|
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
|
||||||
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
|
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
|
||||||
version = "1.1.0"
|
version = "1.2.0"
|
||||||
|
|
||||||
[[deps.PythonCall]]
|
[[deps.PythonCall]]
|
||||||
deps = ["CondaPkg", "Dates", "Libdl", "MacroTools", "Markdown", "Pkg", "REPL", "Requires", "Serialization", "Tables", "UnsafePointers"]
|
deps = ["CondaPkg", "Dates", "Libdl", "MacroTools", "Markdown", "Pkg", "REPL", "Requires", "Serialization", "Tables", "UnsafePointers"]
|
||||||
@@ -456,10 +456,10 @@ uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
|
|||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
|
||||||
[[deps.Rmath_jll]]
|
[[deps.Rmath_jll]]
|
||||||
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
|
deps = ["Artifacts", "JLLWrappers", "Libdl"]
|
||||||
git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da"
|
git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21"
|
||||||
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
|
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
|
||||||
version = "0.4.0+0"
|
version = "0.4.2+0"
|
||||||
|
|
||||||
[[deps.SHA]]
|
[[deps.SHA]]
|
||||||
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
||||||
|
|||||||
@@ -194,9 +194,9 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
:kwargs=> Dict(
|
:kwargs=> Dict(
|
||||||
:max_tokens=> 512,
|
:max_tokens=> 512,
|
||||||
:stop=> ["<|eot_id|>"],
|
:stop=> ["<|eot_id|>"],
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
)
|
||||||
@show outgoingMsg
|
@show outgoingMsg
|
||||||
|
|
||||||
for attempt in 1:5
|
for attempt in 1:5
|
||||||
|
|||||||
@@ -223,12 +223,18 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
|
|||||||
}
|
}
|
||||||
|
|
||||||
Here are some examples:
|
Here are some examples:
|
||||||
|
|
||||||
|
sommelier: "What's your budget?
|
||||||
|
you:
|
||||||
{
|
{
|
||||||
"text": "My budget is 30 USD.",
|
"text": "My budget is 30 USD.",
|
||||||
"select": null,
|
"select": null,
|
||||||
"reward": 0,
|
"reward": 0,
|
||||||
"isterminal": false
|
"isterminal": false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sommelier: "The first option is Zena Crown and the second one is Buano Red."
|
||||||
|
you:
|
||||||
{
|
{
|
||||||
"text": "I like the 2nd option.",
|
"text": "I like the 2nd option.",
|
||||||
"select": 2,
|
"select": 2,
|
||||||
@@ -307,12 +313,12 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
|
|||||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||||
responseDict = copy(JSON3.read(responseJsonStr))
|
responseDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
text = responseDict[:text]
|
text::AbstractString = responseDict[:text]
|
||||||
select = responseDict[:select] == "null" ? nothing : responseDict[:select]
|
select::Union{Nothing, Number} = responseDict[:select] == "null" ? nothing : responseDict[:select]
|
||||||
reward = responseDict[:reward]
|
reward::Number = responseDict[:reward]
|
||||||
isterminal = responseDict[:isterminal]
|
isterminal::Bool = responseDict[:isterminal]
|
||||||
|
|
||||||
if text != "" && select != "" && reward != "" && isterminal != ""
|
if text != ""
|
||||||
# pass test
|
# pass test
|
||||||
else
|
else
|
||||||
error("virtual customer not answer correctly")
|
error("virtual customer not answer correctly")
|
||||||
@@ -332,58 +338,6 @@ function virtualWineUserChatbox(a::T1, input::T2, virtualCustomerChatHistory
|
|||||||
error("virtualWineUserChatbox failed to get a response")
|
error("virtualWineUserChatbox failed to get a response")
|
||||||
end
|
end
|
||||||
|
|
||||||
# function virtualWineUserChatbox(a::T1, input::T2
|
|
||||||
# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
|
||||||
|
|
||||||
# # put in model format
|
|
||||||
# virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
|
|
||||||
# llminfo = virtualWineCustomer[:llminfo]
|
|
||||||
# prompt =
|
|
||||||
# if llminfo[:name] == "llama3instruct"
|
|
||||||
# formatLLMtext_llama3instruct("assistant", input)
|
|
||||||
# else
|
|
||||||
# error("llm model name is not defied yet $(@__LINE__)")
|
|
||||||
# end
|
|
||||||
|
|
||||||
# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
|
||||||
# msgMeta = GeneralUtils.generate_msgMeta(
|
|
||||||
# virtualWineCustomer[:mqtttopic],
|
|
||||||
# senderName= "virtualWineUserChatbox",
|
|
||||||
# senderId= a.id,
|
|
||||||
# receiverName= "virtualWineCustomer",
|
|
||||||
# mqttBroker= a.config[:mqttServerInfo][:broker],
|
|
||||||
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
|
||||||
# msgId = "dummyid" #CHANGE remove after testing finished
|
|
||||||
# )
|
|
||||||
|
|
||||||
# outgoingMsg = Dict(
|
|
||||||
# :msgMeta=> msgMeta,
|
|
||||||
# :payload=> Dict(
|
|
||||||
# :text=> prompt,
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
# attempt = 0
|
|
||||||
# for attempt in 1:5
|
|
||||||
# try
|
|
||||||
# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
|
||||||
# response = result[:response]
|
|
||||||
|
|
||||||
# return (response[:text], response[:select], response[:reward], response[:isterminal])
|
|
||||||
# catch e
|
|
||||||
# io = IOBuffer()
|
|
||||||
# showerror(io, e)
|
|
||||||
# errorMsg = String(take!(io))
|
|
||||||
# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
|
||||||
# println("")
|
|
||||||
# @warn "Error occurred: $errorMsg\n$st"
|
|
||||||
# println("")
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
# error("virtualWineUserChatbox failed to get a response")
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
|
||||||
""" Search wine in stock.
|
""" Search wine in stock.
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -411,6 +365,150 @@ julia> result = winestock(agent, input)
|
|||||||
function winestock(a::T1, input::T2
|
function winestock(a::T1, input::T2
|
||||||
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
||||||
|
|
||||||
|
systemmsg =
|
||||||
|
"""
|
||||||
|
As an attentive sommelier, your mission is to determine the user's preferred levels of sweetness, intensity, tannin, and acidity for a wine based on their input.
|
||||||
|
You'll achieve this by referring to the provided conversion table.
|
||||||
|
|
||||||
|
Conversion Table:
|
||||||
|
Intensity level:
|
||||||
|
Level 1: May correspond to "light-bodied" or a similar description.
|
||||||
|
Level 2: May correspond to "med-light" or a similar description.
|
||||||
|
Level 3: May correspond to "medium" or a similar description.
|
||||||
|
Level 4: May correspond to "med-full" or a similar description.
|
||||||
|
Level 5: May correspond to "full" or a similar description.
|
||||||
|
Sweetness level:
|
||||||
|
Level 1: May correspond to "dry", "no-sweet" or a similar description.
|
||||||
|
Level 2: May correspond to "off-dry", "less-sweet" or a similar description.
|
||||||
|
Level 3: May correspond to "semi-sweet" or a similar description.
|
||||||
|
Level 4: May correspond to "sweet" or a similar description.
|
||||||
|
Level 5: May correspond to "very sweet" or a similar description.
|
||||||
|
Tannin level:
|
||||||
|
Level 1: May correspond to "low tannin" or a similar description.
|
||||||
|
Level 2: May correspond to "semi-low tannin" or a similar description.
|
||||||
|
Level 3: May correspond to "medium tannin" or a similar description.
|
||||||
|
Level 4: May correspond to "semi-high tannin" or a similar description.
|
||||||
|
Level 5: May correspond to "high tannin" or a similar description.
|
||||||
|
Acidity level:
|
||||||
|
Level 1: May correspond to "low acidity" or a similar description.
|
||||||
|
Level 2: May correspond to "semi-low acidity" or a similar description.
|
||||||
|
Level 3: May correspond to "medium acidity" or a similar description.
|
||||||
|
Level 4: May correspond to "semi-high acidity" or a similar description.
|
||||||
|
Level 5: May correspond to "high acidity" or a similar description.
|
||||||
|
|
||||||
|
You should only respond in JSON format as describe below:
|
||||||
|
{
|
||||||
|
"sweetness": "sweetness level",
|
||||||
|
"acidity": "acidity level",
|
||||||
|
"tannin": "tannin level",
|
||||||
|
"intensity": "intensity level"
|
||||||
|
}
|
||||||
|
|
||||||
|
Here are some examples:
|
||||||
|
|
||||||
|
user: red wines, price < 50, body=full-bodied, tannins=1, off dry, acidity=medium, intensity=intense, Thai dishes
|
||||||
|
assistant:
|
||||||
|
{
|
||||||
|
"wine_attributes":
|
||||||
|
{
|
||||||
|
"sweetness": 2,
|
||||||
|
"acidity": 3,
|
||||||
|
"tannin": 1,
|
||||||
|
"intensity": 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Let's begin!
|
||||||
|
"""
|
||||||
|
|
||||||
|
usermsg =
|
||||||
|
"""
|
||||||
|
$input
|
||||||
|
"""
|
||||||
|
|
||||||
|
chathistory =
|
||||||
|
[
|
||||||
|
Dict(:name=> "system", :text=> systemmsg),
|
||||||
|
Dict(:name=> "user", :text=> usermsg)
|
||||||
|
]
|
||||||
|
|
||||||
|
# put in model format
|
||||||
|
prompt = formatLLMtext(chathistory, "llama3instruct")
|
||||||
|
prompt *=
|
||||||
|
"""
|
||||||
|
<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
{
|
||||||
|
"""
|
||||||
|
|
||||||
|
pprint(prompt)
|
||||||
|
externalService = a.config[:externalservice][:text2textinstruct]
|
||||||
|
|
||||||
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
|
externalService[:mqtttopic],
|
||||||
|
senderName= "virtualWineUserChatbox",
|
||||||
|
senderId= a.id,
|
||||||
|
receiverName= "text2textinstruct",
|
||||||
|
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||||
|
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||||
|
msgId = "dummyid" #CHANGE remove after testing finished
|
||||||
|
)
|
||||||
|
|
||||||
|
outgoingMsg = Dict(
|
||||||
|
:msgMeta=> msgMeta,
|
||||||
|
:payload=> Dict(
|
||||||
|
:text=> prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
attempt = 0
|
||||||
|
for attempt in 1:5
|
||||||
|
try
|
||||||
|
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||||
|
_responseJsonStr = response[:response][:text]
|
||||||
|
expectedJsonExample =
|
||||||
|
"""
|
||||||
|
Here is an expected JSON format:
|
||||||
|
{
|
||||||
|
"wine_attributes":
|
||||||
|
{
|
||||||
|
"...": "...",
|
||||||
|
"...": "...",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||||
|
responseDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
|
return (text, select, reward, isterminal)
|
||||||
|
catch e
|
||||||
|
io = IOBuffer()
|
||||||
|
showerror(io, e)
|
||||||
|
errorMsg = String(take!(io))
|
||||||
|
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||||
|
println("")
|
||||||
|
@warn "Error occurred: $errorMsg\n$st"
|
||||||
|
println("")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
error("virtualWineUserChatbox failed to get a response")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
winesStr =
|
winesStr =
|
||||||
"""
|
"""
|
||||||
1: El Enemigo Cabernet Franc 2019
|
1: El Enemigo Cabernet Franc 2019
|
||||||
@@ -425,6 +523,23 @@ function winestock(a::T1, input::T2
|
|||||||
"""
|
"""
|
||||||
return result, nothing, 0, false
|
return result, nothing, 0, false
|
||||||
end
|
end
|
||||||
|
# function winestock(a::T1, input::T2
|
||||||
|
# )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
||||||
|
|
||||||
|
# winesStr =
|
||||||
|
# """
|
||||||
|
# 1: El Enemigo Cabernet Franc 2019
|
||||||
|
# 2: Tantara Chardonnay 2017
|
||||||
|
# """
|
||||||
|
# result =
|
||||||
|
# """
|
||||||
|
# I found the following wines in our stock:
|
||||||
|
# {
|
||||||
|
# $winesStr
|
||||||
|
# }
|
||||||
|
# """
|
||||||
|
# return result, nothing, 0, false
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
""" Attemp to correct LLM response's incorrect JSON response.
|
""" Attemp to correct LLM response's incorrect JSON response.
|
||||||
@@ -446,13 +561,14 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function jsoncorrection(a::T1, input::T2,
|
function jsoncorrection(a::T1, input::T2, correctJsonExample::T3;
|
||||||
correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
|
maxattempt::Integer=3
|
||||||
|
) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
|
||||||
|
|
||||||
incorrectjson = deepcopy(input)
|
incorrectjson = deepcopy(input)
|
||||||
correctjson = nothing
|
correctjson = nothing
|
||||||
|
|
||||||
for attempt in 1:5
|
for attempt in 1:maxattempt
|
||||||
try
|
try
|
||||||
d = copy(JSON3.read(incorrectjson))
|
d = copy(JSON3.read(incorrectjson))
|
||||||
correctjson = incorrectjson
|
correctjson = incorrectjson
|
||||||
|
|||||||
138
src/mcts copy.jl
138
src/mcts copy.jl
@@ -1,138 +0,0 @@
|
|||||||
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
|
|
||||||
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
|
|
||||||
and functions for the MCTS algorithm:
|
|
||||||
"""
|
|
||||||
|
|
||||||
module MCTS
|
|
||||||
|
|
||||||
# export
|
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3, Random
|
|
||||||
using GeneralUtils
|
|
||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
"""
|
|
||||||
struct MCTSNode{T}
|
|
||||||
state::T
|
|
||||||
visits::Int
|
|
||||||
total_reward::Float64
|
|
||||||
children::Dict{T, MCTSNode}
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
"""
|
|
||||||
function select(node::MCTSNode, c::Float64)
|
|
||||||
max_uct = -Inf
|
|
||||||
selected_node = nothing
|
|
||||||
|
|
||||||
for (child_state, child_node) in node.children
|
|
||||||
uct_value = child_node.total_reward / child_node.visits +
|
|
||||||
c * sqrt(log(node.visits) / child_node.visits)
|
|
||||||
if uct_value > max_uct
|
|
||||||
max_uct = uct_value
|
|
||||||
selected_node = child_node
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return selected_node
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
"""
|
|
||||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
|
||||||
for action in actions
|
|
||||||
new_state = transition(node.state, action) # Implement your transition function
|
|
||||||
if new_state ∉ keys(node.children)
|
|
||||||
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
"""
|
|
||||||
function simulate(state::T, max_depth::Int)
|
|
||||||
total_reward = 0.0
|
|
||||||
for _ in 1:max_depth
|
|
||||||
action = select_action(state) # Implement your action selection function
|
|
||||||
state, reward = transition(state, action) # Implement your transition function
|
|
||||||
total_reward += reward
|
|
||||||
end
|
|
||||||
return total_reward
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
"""
|
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
|
||||||
node.visits += 1
|
|
||||||
node.total_reward += reward
|
|
||||||
if !isempty(node.children)
|
|
||||||
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
|
||||||
backpropagate(node.children[best_child], -reward)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
[] implement transition()
|
|
||||||
"""
|
|
||||||
function transition(state, action)
|
|
||||||
|
|
||||||
end
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
|
||||||
# Create a complete example using the defined MCTS functions #
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
|
||||||
"""
|
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
"""
|
|
||||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
|
||||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
|
||||||
|
|
||||||
for _ in 1:max_iterations
|
|
||||||
node = root
|
|
||||||
while !is_leaf(node)
|
|
||||||
node = select(node, w)
|
|
||||||
end
|
|
||||||
|
|
||||||
expand(node, node.state, actions)
|
|
||||||
|
|
||||||
leaf_node = node.children[node.state]
|
|
||||||
reward = simulate(leaf_node.state, max_depth)
|
|
||||||
backpropagate(leaf_node, reward)
|
|
||||||
end
|
|
||||||
|
|
||||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
|
||||||
return best_child_state
|
|
||||||
end
|
|
||||||
|
|
||||||
# Define your transition function and action selection function here
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
initial_state = 0
|
|
||||||
actions = [-1, 0, 1]
|
|
||||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
|
||||||
println("Best action to take: ", best_action)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end
|
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
|
""" https://www.harrycodes.com/blog/monte-carlo-tree-search
|
||||||
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
|
|
||||||
and functions for the MCTS algorithm:
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
module mcts
|
module mcts
|
||||||
|
|||||||
226
src/type copy.jl
226
src/type copy.jl
@@ -1,226 +0,0 @@
|
|||||||
module type
|
|
||||||
|
|
||||||
export agent, sommelier
|
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3
|
|
||||||
using GeneralUtils
|
|
||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
|
||||||
|
|
||||||
abstract type agent end
|
|
||||||
|
|
||||||
|
|
||||||
""" A sommelier agent.
|
|
||||||
|
|
||||||
# Arguments
|
|
||||||
- `mqttClient::Client`
|
|
||||||
MQTTClient's client
|
|
||||||
- `msgMeta::Dict{Symbol, Any}`
|
|
||||||
A dict contain info about a message.
|
|
||||||
- `config::Dict{Symbol, Any}`
|
|
||||||
Config info for an agent. Contain mqtt topic for internal use and other info.
|
|
||||||
|
|
||||||
# Keyword Arguments
|
|
||||||
- `name::String`
|
|
||||||
Agent's name
|
|
||||||
- `id::String`
|
|
||||||
Agent's ID
|
|
||||||
- `tools::Dict{Symbol, Any}`
|
|
||||||
Agent's tools
|
|
||||||
- `maxHistoryMsg::Integer`
|
|
||||||
max history message
|
|
||||||
|
|
||||||
# Return
|
|
||||||
- `nothing`
|
|
||||||
|
|
||||||
# Example
|
|
||||||
```jldoctest
|
|
||||||
julia> using YiemAgent, MQTTClient, GeneralUtils
|
|
||||||
julia> msgMeta = GeneralUtils.generate_msgMeta(
|
|
||||||
"N/A",
|
|
||||||
replyTopic = "/testtopic/prompt"
|
|
||||||
)
|
|
||||||
julia> tools= Dict(
|
|
||||||
:chatbox=>Dict(
|
|
||||||
:name => "chatbox",
|
|
||||||
:description => "Useful only for when you need to ask the user for more info or context. Do not ask the user their own question.",
|
|
||||||
:input => "Input should be a text.",
|
|
||||||
:output => "" ,
|
|
||||||
:func => nothing,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
julia> agentConfig = Dict(
|
|
||||||
:receiveprompt=>Dict(
|
|
||||||
:mqtttopic=> "/testtopic/prompt", # topic to receive prompt i.e. frontend send msg to this topic
|
|
||||||
),
|
|
||||||
:receiveinternal=>Dict(
|
|
||||||
:mqtttopic=> "/testtopic/internal", # receive topic for model's internal
|
|
||||||
),
|
|
||||||
:text2text=>Dict(
|
|
||||||
:mqtttopic=> "/text2text/receive",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
julia> client, connection = MakeConnection("test.mosquitto.org", 1883)
|
|
||||||
julia> agent = YiemAgent.bsommelier(
|
|
||||||
client,
|
|
||||||
msgMeta,
|
|
||||||
agentConfig,
|
|
||||||
name= "assistant",
|
|
||||||
id= "555", # agent instance id
|
|
||||||
tools=tools,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
# TODO
|
|
||||||
- [] update docstring
|
|
||||||
- [x] implement the function
|
|
||||||
|
|
||||||
# Signature
|
|
||||||
"""
|
|
||||||
@kwdef mutable struct sommelier <: agent
|
|
||||||
name::String # agent name
|
|
||||||
id::String # agent id
|
|
||||||
config::Dict # agent config
|
|
||||||
tools::Dict
|
|
||||||
thinkinglimit::Integer # thinking round limit
|
|
||||||
thinkingcount::Integer # used to count attempted round of a task
|
|
||||||
|
|
||||||
""" Memory
|
|
||||||
Ref: Chat prompt format https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/discussions/3
|
|
||||||
NO "system" message in chathistory because I want to add it at the inference time
|
|
||||||
chathistory= [
|
|
||||||
Dict(:name=>"user", :text=> "Wassup!", :timestamp=> Dates.now()),
|
|
||||||
Dict(:name=>"assistant", :text=> "Hi I'm your assistant.", :timestamp=> Dates.now()),
|
|
||||||
]
|
|
||||||
|
|
||||||
"""
|
|
||||||
chathistory::Vector{Dict{Symbol, Any}} = Vector{Dict{Symbol, Any}}()
|
|
||||||
|
|
||||||
maxHistoryMsg::Integer # 21th and earlier messages will get summarized
|
|
||||||
keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}(
|
|
||||||
:customerinfo => Dict{Symbol, Any}(),
|
|
||||||
:storeinfo => Dict{Symbol, Any}(),
|
|
||||||
)
|
|
||||||
mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
|
|
||||||
|
|
||||||
# 1-historyPoint is in Dict{Symbol, Any} and compose of:
|
|
||||||
# state, statevalue, thought, action, observation
|
|
||||||
plan::Dict{Symbol, Any} = Dict{Symbol, Any}(
|
|
||||||
|
|
||||||
# store 3 to 5 best plan AI frequently used to avoid having to search MCTS all the time
|
|
||||||
# each plan is in [historyPoint_1, historyPoint_2, ...] format
|
|
||||||
:existingplan => Vector(),
|
|
||||||
|
|
||||||
:activeplan => Dict{Symbol, Any}(), # current using plan
|
|
||||||
:currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
|
|
||||||
)
|
|
||||||
|
|
||||||
end
|
|
||||||
|
|
||||||
function sommelier(
|
|
||||||
config::Dict = Dict(
|
|
||||||
:mqttServerInfo=> Dict(
|
|
||||||
:broker=> nothing,
|
|
||||||
:port=> nothing,
|
|
||||||
),
|
|
||||||
:receivemsg=> Dict(
|
|
||||||
:prompt=> nothing, # topic to receive prompt i.e. frontend send msg to this topic
|
|
||||||
:internal=> nothing,
|
|
||||||
),
|
|
||||||
:thirdPartyService=> Dict(
|
|
||||||
:text2textinstruct=> nothing,
|
|
||||||
:text2textchat=> nothing,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
;
|
|
||||||
name::String= "Assistant",
|
|
||||||
id::String= string(uuid4()),
|
|
||||||
tools::Dict= Dict(
|
|
||||||
:chatbox=> Dict(
|
|
||||||
:name => "chatbox",
|
|
||||||
:description => "Useful for when you need to communicate with the user.",
|
|
||||||
:input => "Input should be a conversation to the user.",
|
|
||||||
:output => "" ,
|
|
||||||
:func => nothing,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
maxHistoryMsg::Integer= 20,
|
|
||||||
thinkinglimit::Integer= 5,
|
|
||||||
thinkingcount::Integer= 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
#[NEXTVERSION] publish to a.config[:configtopic] to get a config.
|
|
||||||
#[NEXTVERSION] get a config message in a.mqttMsg_internal
|
|
||||||
#[NEXTVERSION] set agent according to config
|
|
||||||
|
|
||||||
newAgent = sommelier(
|
|
||||||
name= name,
|
|
||||||
id= id,
|
|
||||||
config= config,
|
|
||||||
maxHistoryMsg= maxHistoryMsg,
|
|
||||||
tools= tools,
|
|
||||||
thinkinglimit= thinkinglimit,
|
|
||||||
thinkingcount= thinkingcount,
|
|
||||||
)
|
|
||||||
|
|
||||||
return newAgent
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end # module type
|
|
||||||
@@ -59,26 +59,26 @@ tools=Dict( # update input format
|
|||||||
|
|
||||||
# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
|
# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
|
||||||
|
|
||||||
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
|
# response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
|
||||||
:select=> nothing,
|
# :select=> nothing,
|
||||||
:reward=> 0,
|
# :reward=> 0,
|
||||||
:isterminal=> false,
|
# :isterminal=> false,
|
||||||
) )
|
# ) )
|
||||||
println("---> YiemAgent: ", response)
|
# println("---> YiemAgent: ", response)
|
||||||
|
|
||||||
response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.",
|
# #BUG mcts do not start at current chat history
|
||||||
:select=> nothing,
|
# response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening. I'll pay at most 30 bucks.",
|
||||||
:reward=> 0,
|
# :select=> nothing,
|
||||||
:isterminal=> false,
|
# :reward=> 0,
|
||||||
) )
|
# :isterminal=> false,
|
||||||
println("---> YiemAgent: ", response)
|
# ) )
|
||||||
|
# println("---> YiemAgent: ", response)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
dummyinput = "price < 50, full-bodied red wine with sweetness level 2, low tannin level and medium acidity level, Thai dishes"
|
||||||
"It will be Thai dishes."
|
response = YiemAgent.winestock(a, dummyinput)
|
||||||
"I like medium-bodied with low tannin."
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user