This commit is contained in:
narawat lamaiin
2024-04-21 16:19:32 +07:00
parent b8d036e800
commit ee1446b1e2
4 changed files with 651 additions and 147 deletions

View File

@@ -4,7 +4,7 @@ export addNewMessage, conversation
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
using GeneralUtils using GeneralUtils
using ..type, ..util, ..llmfunction using ..type, ..util, ..llmfunction, ..mcts
# ------------------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------------------ #
# pythoncall setting # # pythoncall setting #
@@ -85,102 +85,105 @@ using ..type, ..util, ..llmfunction
Signature\n Signature\n
----- -----
""" """
function conversation(a::T) where {T<:agent} function conversation(a::T, userinput::Dict) where {T<:agent}
""" """
[] update document [] update document
[] MCTS() for planning [x] MCTS() for planning
""" """
while true # "newtopic" command to delete chat history
# check for incoming user message if userinput[:text] == "newtopic"
if isready(a.receiveUserMsgChannel) clearhistory(a)
incomingMsg = take!(a.receiveUserMsgChannel)
incomingPayload = incomingMsg[:payload]
# "newtopic" command to delete chat history return "Okay. What shall we talk about?"
if incomingPayload[:text] == "newtopic"
clearhistory(a)
msgMeta = deepcopy(a.msgMeta)
msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
msgMeta[:senderName] = "agent-backend"
msgMeta[:senderId] = a.id
msgMeta[:receiverName] = "agent-frontend"
msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
msgMeta[:msgId] = string(uuid4())
msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
outgoingMsg = Dict( else
:msgMeta=> msgMeta, # add usermsg to a.chathistory
:payload=> Dict( addNewMessage(a, "user", userinput[:text])
:name=> a.name, # will be shown in frontend as agent name
:text => "Okay. What shall we talk about?",
)
)
_ = GeneralUtils.sendMqttMsg(outgoingMsg)
else # a new thinking #[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned
# add usermsg to a.chathistory if !isempty(a.plan[:currenttrajectory]) &&
addNewMessage(a, "user", usermsg) a.plan[:currenttrajectory][end][:action] == "chatbox"
#[WORKING] if the last used tool is a chatbox
if a.plan[:currenttrajectory][end][:action] == "chatbox"
#usermsg -> observation and continue actor loop as planned
else #[WORKING] new thinking
else
#planning with MCTS() -> best plan
#actor loop(best plan) initialState = 0
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0)
error("---> bestplan")
# actor loop(bestplan)
end end
end
end
sleep(1)
end end
# workstate = nothing
# response = nothing
# _ = addNewMessage(a, "user", usermsg)
# isuseplan = isUsePlans(a)
# # newinfo = extractinfo(a, usermsg)
# # a.env = newinfo !== nothing ? updateEnvState(a, newinfo) : a.env
# @show isuseplan
# if isuseplan # use plan before responding
# if haskey(a.memory[:shortterm], "User:") == false #[] should change role if user want to buy wine.
# a.memory[:shortterm]["User:"] = usermsg
# end
# workstate, response = work(a)
# end
# # if LLM using askbox, use returning msg form askbox as conversation response
# if workstate == "askbox" || workstate == "formulatedUserResponse"
# #[] paraphrase msg so that it is human friendlier word.
# else
# response = chat_mistral_openorca(a)
# response = split(response, "\n\n")[1]
# response = split(response, "\n\n")[1]
# end
# response = removeTrailingCharacters(response)
# _ = addNewMessage(a, "assistant", response)
end end
# function conversation(a::T) where {T<:agent}
# """
# [] update document
# [x] MCTS() for planning
# """
# while true
# # check for incoming user message
# if isready(a.receiveUserMsgChannel)
# incomingMsg = take!(a.receiveUserMsgChannel)
# incomingPayload = incomingMsg[:payload]
# @show incomingMsg
# # "newtopic" command to delete chat history
# if incomingPayload[:text] == "newtopic"
# clearhistory(a)
# msgMeta = deepcopy(a.msgMeta)
# msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
# msgMeta[:senderName] = "agent-backend"
# msgMeta[:senderId] = a.id
# msgMeta[:receiverName] = "agent-frontend"
# msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
# msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
# msgMeta[:msgId] = string(uuid4())
# msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :name=> a.name, # will be shown in frontend as agent name
# :text => "Okay. What shall we talk about?",
# )
# )
# # _ = GeneralUtils.sendMqttMsg(outgoingMsg)
# else
# @show a = 55555
# # add usermsg to a.chathistory
# addNewMessage(a, "user", usermsg)
# #[] if the last used tool is a chatbox
# if a.plan[:currenttrajectory][end][:action] == "chatbox"
# #usermsg -> observation and continue actor loop as planned
# else #[WORKING] new thinking
# initialState = 0
# bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
# 3, 10, 1000, 1.0)
# # actor loop(best plan)
# end
# end
# end
# sleep(1)
# end
# end

287
src/mcts copy 2.jl Normal file
View File

@@ -0,0 +1,287 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, this module defines the tree node type together with the
select, expand, simulate and backpropagate steps and the `run_mcts` driver that combines them.
"""
module MCTS
# export
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
mutable struct MCTSNode{T}
    # One node of the search tree.
    # Declared `mutable` because `backpropagate` updates `visits` and `total_reward` in
    # place; assigning a field of an immutable struct throws an error.
    state::T                      # state represented by this node
    visits::Int                   # visit counter, incremented during backpropagation
    total_reward::Float64         # reward accumulated during backpropagation
    children::Dict{T, MCTSNode}   # child nodes, keyed by the child's state
end

# Lenient outer constructor: accepts any integer / real / dictionary arguments and converts
# them to the declared field types. This lets a root be created with an untyped, empty
# `Dict()` (as `run_mcts` does), which the default constructor rejects because
# `Dict{Any, Any}` is not a `Dict{T, MCTSNode}`.
MCTSNode(state::T, visits::Integer, total_reward::Real, children::AbstractDict) where {T} =
    MCTSNode{T}(state, visits, total_reward, children)
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
function select(node::MCTSNode, c::Float64)
    # Pick the child of `node` with the highest UCT score
    #     total_reward / visits + c * sqrt(log(parent visits) / visits)
    # where `c` is the exploration weight.
    # Returns the chosen child node, or `nothing` when `node` has no children.

    # Guard against log(0) == -Inf, which would put a negative number under the sqrt
    # (DomainError) when the parent has not been visited yet.
    log_parent_visits = log(max(node.visits, 1))

    best_score = -Inf
    selected_node = nothing
    for child_node in values(node.children)
        # An unvisited child has no mean reward yet (0.0 / 0 is NaN, and NaN never compares
        # greater than anything), so it would otherwise never be chosen. Explore it first.
        child_node.visits == 0 && return child_node

        score = child_node.total_reward / child_node.visits +
                c * sqrt(log_parent_visits / child_node.visits)
        # `selected_node === nothing` makes sure a child is returned even if every score is
        # -Inf or NaN.
        if selected_node === nothing || score > best_score
            best_score = score
            selected_node = child_node
        end
    end
    return selected_node
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, actions::Vector{T}) where {T}
    # Add one child to `node` for every state reachable through `actions`.
    #   node    - node to expand; its `children` dictionary is modified in place
    #   state   - kept for interface compatibility; the loop reads `node.state` instead
    #   actions - candidate actions, each handed to `transition`
    # Returns `nothing`.
    #
    # `where {T}` binds the type parameter used in the signature; without it `T` is an
    # undefined name and the method cannot be defined.
    #
    # NOTE(review): the value returned by `transition` is used here directly as the new
    # state, while `simulate` destructures it as `(state, reward)` — confirm which contract
    # `transition` is meant to follow.
    for action in actions
        new_state = transition(node.state, action)
        # Only create a child for states that are not already present, so statistics of
        # existing children are never overwritten.
        if !haskey(node.children, new_state)
            node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
        end
    end
    return nothing
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function simulate(state::T, max_depth::Int) where {T}
    # Roll out from `state` for `max_depth` steps and return the summed reward (Float64).
    # Each step asks `select_action` for an action and applies it with `transition`, which
    # is expected here to return a `(next_state, reward)` pair.
    #
    # `where {T}` binds the type parameter used in the signature; without it the method
    # definition fails with an UndefVarError for `T`.
    #
    # NOTE(review): `select_action` is not defined in this module and `transition` is still
    # an empty stub; both must be provided before this function can run.
    total_reward = 0.0
    for _ in 1:max_depth
        action = select_action(state)
        state, reward = transition(state, action)
        total_reward += reward
    end
    return total_reward
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function backpropagate(node::MCTSNode, reward::Float64)
    # Record one visit and `reward` on `node`, then continue with the child that has the
    # highest mean reward, passing it the negated reward. Returns `nothing`.
    # `node` (and the children reached) are modified in place.
    #
    # NOTE(review): this walks from `node` towards its children. The node type has no parent
    # link, so ancestors of `node` are not updated here — confirm this is the intended
    # direction.
    node.visits += 1
    node.total_reward += reward
    isempty(node.children) && return nothing

    # Keep a reference to the best child itself. Taking `argmax` over an array of means
    # yields a position in that array, which is not a key of `node.children`.
    best_child = nothing
    best_mean = -Inf
    for child in values(node.children)
        # An unvisited child would give 0.0 / 0 == NaN; treat its mean as 0.0 instead.
        mean_reward = child.visits == 0 ? 0.0 : child.total_reward / child.visits
        if best_child === nothing || mean_reward > best_mean
            best_mean = mean_reward
            best_child = child
        end
    end
    backpropagate(best_child, -reward)
    return nothing
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
# Placeholder for the environment model: accepts a state and an action and currently
# returns `nothing` for every input.
transition(state, action) = nothing
""" Check whether a node is a leaf node
Arguments\n
-----
Return\n
-----
`true` if `node` has no children, otherwise `false`
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[DONE] implement isLeaf()
Signature\n
-----
"""
function isLeaf(node::MCTSNode)::Bool
    # A node is a leaf when its `children` dictionary holds no entries.
    return length(node.children) == 0
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
    # Run `max_iterations` rounds of select -> expand -> simulate -> backpropagate starting
    # from `initial_state`.
    #   actions        - candidate actions handed to `expand`
    #   max_iterations - number of search rounds
    #   max_depth      - rollout length handed to `simulate`
    #   w              - exploration weight handed to `select`
    # Returns the state of the root child with the highest mean reward, or `nothing` when
    # the root never received a visited child.

    # The children dictionary is typed explicitly: an untyped `Dict()` is a
    # `Dict{Any, Any}`, which does not match the node's `children` field type.
    root = MCTSNode(initial_state, 0, 0.0, Dict{typeof(initial_state), MCTSNode}())

    for _ in 1:max_iterations
        # Selection: descend to a leaf, remembering every node passed on the way.
        ancestors = MCTSNode[]
        node = root
        while !isLeaf(node)
            push!(ancestors, node)
            node = select(node, w)
        end

        # Expansion: children are keyed by their own state, so looking one up under the
        # parent's state is wrong. Take one of the freshly created children instead, or
        # fall back to the node itself when expansion produced nothing.
        expand(node, node.state, actions)
        if isLeaf(node)
            leaf_node = node
        else
            push!(ancestors, node)
            leaf_node = first(values(node.children))
        end

        # Simulation
        reward = simulate(leaf_node.state, max_depth)

        # Backpropagation: `leaf_node` has no children at this point, so `backpropagate`
        # only updates the leaf. The nodes on the path are updated here, which keeps
        # `root.visits` positive for the `log(node.visits)` term in `select`.
        backpropagate(leaf_node, reward)
        for ancestor in ancestors
            ancestor.visits += 1
            ancestor.total_reward += reward
        end
    end

    # Return the child *state* (the dictionary key). `argmax` over an array of means would
    # return a position in that array instead.
    found = false
    best_child_state = nothing
    best_mean = -Inf
    for (child_state, child) in root.children
        child.visits == 0 && continue   # no statistics yet; 0.0 / 0 would be NaN
        mean_reward = child.total_reward / child.visits
        if !found || mean_reward > best_mean
            found = true
            best_mean = mean_reward
            best_child_state = child_state
        end
    end
    return best_child_state
end
# Define your transition function and action selection function here
# NOTE(review): `simulate` calls `select_action`, which is not defined in this module (unless
# GeneralUtils exports it — confirm), and `transition` is still an empty stub. The example
# below cannot succeed until both are implemented.
# Example usage
# These statements sit at module top level, so they run whenever the module is loaded and
# leave `initial_state`, `actions` and `best_action` behind as module globals.
initial_state = 0
actions = [-1, 0, 1]
# positional arguments after `actions`: max_iterations = 1000, max_depth = 10, w = 1.0
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
println("Best action to take: ", best_action)
end

View File

@@ -3,9 +3,9 @@
and functions for the MCTS algorithm: and functions for the MCTS algorithm:
""" """
module MCTS module mcts
# export export runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils using GeneralUtils
@@ -37,7 +37,7 @@ using GeneralUtils
struct MCTSNode{T} struct MCTSNode{T}
state::T state::T
visits::Int visits::Int
total_reward::Float64 stateValue::Float64
children::Dict{T, MCTSNode} children::Dict{T, MCTSNode}
end end
@@ -45,7 +45,10 @@ end
Arguments\n Arguments\n
----- -----
node::MCTSNode
mcts node
w::Float64
exploration weight
Return\n Return\n
----- -----
@@ -58,25 +61,70 @@ end
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[] implement the function [DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
Signature\n Signature\n
----- -----
""" """
function select(node::MCTSNode, c::Float64) function select(node::MCTSNode, w::Float64)
max_uct = -Inf max_uct = -Inf
selected_node = nothing selectedNode = nothing
for (child_state, child_node) in node.children for (childState, childNode) in node.children
uct_value = child_node.total_reward / child_node.visits + uctValue = childNode.stateValue +
c * sqrt(log(node.visits) / child_node.visits) w * sqrt(log(node.visits) / childNode.visits)
if uct_value > max_uct if uctValue > max_uct
max_uct = uct_value max_uct = uctValue
selected_node = child_node selectedNode = childNode
end end
end end
return selected_node return selectedNode
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T<:Any}
actions = []
# sampling action from decisionMaker
# for nth in 1:n
# end
for action in actions
newState = transition(node.state, action) # Implement your transition function
if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
end
end end
""" """
@@ -101,38 +149,7 @@ end
Signature\n Signature\n
----- -----
""" """
function expand(node::MCTSNode, state::T, actions::Vector{T}) function simulate(state::T, max_depth::Int) where {T<:Any}
for action in actions
new_state = transition(node.state, action) # Implement your transition function
if new_state keys(node.children)
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function simulate(state::T, max_depth::Int)
total_reward = 0.0 total_reward = 0.0
for _ in 1:max_depth for _ in 1:max_depth
action = select_action(state) # Implement your action selection function action = select_action(state) # Implement your action selection function
@@ -224,9 +241,6 @@ end
""" """
isLeaf(node::MCTSNode)::Bool = isempty(node.children) isLeaf(node::MCTSNode)::Bool = isempty(node.children)
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" """
Arguments\n Arguments\n
@@ -244,37 +258,128 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[] implement the function
[] implement RAG to pull similar experience
Signature\n Signature\n
----- -----
""" """
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64) function decisionMaker()
root = MCTSNode(initial_state, 0, 0.0, Dict())
for _ in 1:max_iterations end
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(node, node.state, actions) """
leaf_node = node.children[node.state] Arguments\n
reward = simulate(leaf_node.state, max_depth) -----
backpropagate(leaf_node, reward)
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" Search for best action
Arguments\n
-----
initial state
initial state
decisionMaker::Function
decide what action to take
stateValueEstimator::Function
assess the value of the state
reflector::Function
generate lesson from trajectory and reward
n::Integer
how many times action will be sampled from decisionMaker
w::Float64
exploration weight
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64)
root = MCTSNode(initialState, 0, 0.0, Dict())
for _ in 1:maxIterations
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(node, node.state, decisionMaker, stateValueEstimator,
n=n)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state return best_child_state
end end
# Define your transition function and action selection function here
# Example usage
initial_state = 0
actions = [-1, 0, 1]
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
println("Best action to take: ", best_action)

109
test/runtest.jl Normal file
View File

@@ -0,0 +1,109 @@
using Revise # remove when this package is completed
using YiemAgent, GeneralUtils, JSON3, MQTTClient, Dates, UUIDs
using Base.Threads
# ---------------------------------------------- 100 --------------------------------------------- #
# Manual smoke-test script: builds a sommelier agent from `config.json` and sends it the
# "newtopic" command.
# NOTE(review): presumably needs a reachable MQTT broker and a `config.json` in the working
# directory — confirm before running in CI.

# Load the service configuration; `copy` turns the parsed JSON object into a plain collection.
config = copy(JSON3.read("config.json"))
# Internal topic for this instance: the configured internal topic with "/1" appended.
instanceInternalTopic = config[:serviceInternalTopic][:value] * "/1"
# Connect using the broker and port given in the configuration.
client, connection = MakeConnection(config[:mqttServerInfo][:value][:broker],
config[:mqttServerInfo][:value][:port])
# Buffered channels (capacity 4) handed to the agent for user and internal messages.
receiveUserMsgChannel = Channel{Dict}(4)
receiveInternalMsgChannel = Channel{Dict}(4)
# Message-metadata template; replies are directed to the service topic.
msgMeta = GeneralUtils.generate_msgMeta(
"N/A",
replyTopic = config[:servicetopic][:value] # ask frontend reply to this instance_chat_topic
)
# Topics the agent uses, grouped by purpose.
agentConfig = Dict(
:receiveprompt=>Dict(
:mqtttopic=> config[:servicetopic][:value], # topic to receive prompt i.e. frontend send msg to this topic
),
:receiveinternal=>Dict(
:mqtttopic=> instanceInternalTopic, # receive topic for model's internal
),
:text2text=>Dict(
:mqtttopic=> config[:text2text][:value],
),
)
# Instantiate an agent
# Tool table: each entry carries a description, an input format, an output format and the
# function to call (`nothing` for tools without an implementation here).
tools=Dict( # update input format
"askbox"=> Dict(
:description => "<askbox tool description>Useful for when you need to ask the user for more context. Do not ask the user their own question.</askbox tool description>",
:input => """<input>Input is a text in JSON format.</input><input example>{\"Q1\": \"How are you doing?\", \"Q2\": \"How may I help you?\"}</input example>""",
:output => "" ,
:func => nothing,
),
# "winestock"=> Dict(
# :description => "<winestock tool description>A handy tool for searching wine in your inventory that match the user preferences.</winestock tool description>",
# :input => """<input>Input is a JSON-formatted string that contains a detailed and precise search query.</input><input example>{\"wine type\": \"rose\", \"price\": \"max 35\", \"sweetness level\": \"sweet\", \"intensity level\": \"light bodied\", \"Tannin level\": \"low\", \"Acidity level\": \"low\"}</input example>""",
# :output => """<output>Output are wines that match the search query in JSON format.""",
# :func => ChatAgent.winestock,
# ),
"finalanswer"=> Dict(
:description => "<tool description>Useful for when you are ready to recommend wines to the user.</tool description>",
:input => """<input format>{\"finalanswer\": \"some text\"}.</input format><input example>{\"finalanswer\": \"I recommend Zena Crown Vista\"}</input example>""",
:output => "" ,
:func => nothing,
),
)
# Build the agent from the channels, metadata template, topic configuration and tool table.
a = YiemAgent.sommelier(
receiveUserMsgChannel,
receiveInternalMsgChannel,
msgMeta,
agentConfig,
name= "assistant",
id= "randomSessionID", # agent instance id
tools=tools,
)
# Send the "newtopic" command and keep whatever `conversation` returns.
response = YiemAgent.conversation(a, Dict(:text=> "newtopic", ) )