diff --git a/src/interface.jl b/src/interface.jl
index fb5b4c9..d5ff814 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -4,7 +4,7 @@ export addNewMessage, conversation
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
using GeneralUtils
-using ..type, ..util, ..llmfunction
+using ..type, ..util, ..llmfunction, ..mcts
# ------------------------------------------------------------------------------------------------ #
# pythoncall setting #
@@ -85,102 +85,105 @@ using ..type, ..util, ..llmfunction
Signature\n
-----
"""
-function conversation(a::T) where {T<:agent}
+function conversation(a::T, userinput::Dict) where {T<:agent}
"""
[] update document
- [] MCTS() for planning
+ [x] MCTS() for planning
"""
- while true
- # check for incoming user message
- if isready(a.receiveUserMsgChannel)
- incomingMsg = take!(a.receiveUserMsgChannel)
- incomingPayload = incomingMsg[:payload]
+ # "newtopic" command to delete chat history
+ if userinput[:text] == "newtopic"
+ clearhistory(a)
- # "newtopic" command to delete chat history
- if incomingPayload[:text] == "newtopic"
- clearhistory(a)
- msgMeta = deepcopy(a.msgMeta)
- msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
- msgMeta[:senderName] = "agent-backend"
- msgMeta[:senderId] = a.id
- msgMeta[:receiverName] = "agent-frontend"
- msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
- msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
- msgMeta[:msgId] = string(uuid4())
- msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
+ return "Okay. What shall we talk about?"
- outgoingMsg = Dict(
- :msgMeta=> msgMeta,
- :payload=> Dict(
- :name=> a.name, # will be shown in frontend as agent name
- :text => "Okay. What shall we talk about?",
- )
- )
- _ = GeneralUtils.sendMqttMsg(outgoingMsg)
+ else
+ # add usermsg to a.chathistory
+ addNewMessage(a, "user", userinput[:text])
- else # a new thinking
- # add usermsg to a.chathistory
- addNewMessage(a, "user", usermsg)
+ #[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned
+ if !isempty(a.plan[:currenttrajectory]) &&
+ a.plan[:currenttrajectory][end][:action] == "chatbox"
- #[WORKING] if the last used tool is a chatbox
- if a.plan[:currenttrajectory][end][:action] == "chatbox"
- #usermsg -> observation and continue actor loop as planned
+
+
+ else #[WORKING] new thinking
- else
- #planning with MCTS() -> best plan
- #actor loop(best plan)
+ initialState = 0
+ bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
+ 3, 10, 1000, 1.0)
+ error("---> bestplan")
+ # actor loop(bestplan)
- end
- end
- end
- sleep(1)
+ end
end
-
-
-
-
-
-
-
-
-
-
-
- # workstate = nothing
- # response = nothing
-
- # _ = addNewMessage(a, "user", usermsg)
- # isuseplan = isUsePlans(a)
- # # newinfo = extractinfo(a, usermsg)
- # # a.env = newinfo !== nothing ? updateEnvState(a, newinfo) : a.env
- # @show isuseplan
-
- # if isuseplan # use plan before responding
- # if haskey(a.memory[:shortterm], "User:") == false #[] should change role if user want to buy wine.
- # a.memory[:shortterm]["User:"] = usermsg
- # end
- # workstate, response = work(a)
- # end
-
- # # if LLM using askbox, use returning msg form askbox as conversation response
- # if workstate == "askbox" || workstate == "formulatedUserResponse"
- # #[] paraphrase msg so that it is human friendlier word.
- # else
- # response = chat_mistral_openorca(a)
- # response = split(response, "\n\n")[1]
- # response = split(response, "\n\n")[1]
- # end
-
- # response = removeTrailingCharacters(response)
- # _ = addNewMessage(a, "assistant", response)
-
-
end
+# function conversation(a::T) where {T<:agent}
+# """
+# [] update document
+# [x] MCTS() for planning
+# """
+# while true
+# # check for incoming user message
+# if isready(a.receiveUserMsgChannel)
+# incomingMsg = take!(a.receiveUserMsgChannel)
+# incomingPayload = incomingMsg[:payload]
+# @show incomingMsg
+
+# # "newtopic" command to delete chat history
+# if incomingPayload[:text] == "newtopic"
+# clearhistory(a)
+# msgMeta = deepcopy(a.msgMeta)
+# msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
+# msgMeta[:senderName] = "agent-backend"
+# msgMeta[:senderId] = a.id
+# msgMeta[:receiverName] = "agent-frontend"
+# msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
+# msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
+# msgMeta[:msgId] = string(uuid4())
+# msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
+
+# outgoingMsg = Dict(
+# :msgMeta=> msgMeta,
+# :payload=> Dict(
+# :name=> a.name, # will be shown in frontend as agent name
+# :text => "Okay. What shall we talk about?",
+# )
+# )
+# # _ = GeneralUtils.sendMqttMsg(outgoingMsg)
+
+# else
+# @show a = 55555
+# # add usermsg to a.chathistory
+# addNewMessage(a, "user", usermsg)
+
+# #[] if the last used tool is a chatbox
+# if a.plan[:currenttrajectory][end][:action] == "chatbox"
+# #usermsg -> observation and continue actor loop as planned
+
+
+
+# else #[WORKING] new thinking
+
+
+# initialState = 0
+# bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
+# 3, 10, 1000, 1.0)
+
+# # actor loop(best plan)
+
+# end
+# end
+# end
+# sleep(1)
+# end
+# end
+
+
diff --git a/src/mcts copy 2.jl b/src/mcts copy 2.jl
new file mode 100644
index 0000000..bc5b513
--- /dev/null
+++ b/src/mcts copy 2.jl
@@ -0,0 +1,287 @@
+""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
+ Bound for Trees) selection function, you can follow the steps below: Define the necessary types
+ and functions for the MCTS algorithm:
+"""
+
+module MCTS
+
+# export
+
+using Dates, UUIDs, DataStructures, JSON3, Random
+using GeneralUtils
+
+# ---------------------------------------------- 100 --------------------------------------------- #
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+struct MCTSNode{T}
+ state::T
+ visits::Int
+ total_reward::Float64
+ children::Dict{T, MCTSNode}
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing
+
+ Signature\n
+ -----
+"""
+function select(node::MCTSNode, c::Float64)
+ max_uct = -Inf
+ selected_node = nothing
+
+ for (child_state, child_node) in node.children
+ uct_value = child_node.total_reward / child_node.visits +
+ c * sqrt(log(node.visits) / child_node.visits)
+ if uct_value > max_uct
+ max_uct = uct_value
+ selected_node = child_node
+ end
+ end
+
+ return selected_node
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+function expand(node::MCTSNode, state::T, actions::Vector{T})
+ for action in actions
+ new_state = transition(node.state, action) # Implement your transition function
+ if new_state ∉ keys(node.children)
+ node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
+ end
+ end
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+function simulate(state::T, max_depth::Int)
+ total_reward = 0.0
+ for _ in 1:max_depth
+ action = select_action(state) # Implement your action selection function
+ state, reward = transition(state, action) # Implement your transition function
+ total_reward += reward
+ end
+ return total_reward
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+function backpropagate(node::MCTSNode, reward::Float64)
+ node.visits += 1
+ node.total_reward += reward
+ if !isempty(node.children)
+ best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
+ backpropagate(node.children[best_child], -reward)
+ end
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+function transition(state, action)
+
+end
+
+""" Check whether a node is a leaf node
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+ a task represent an agent
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [DONE] implement isLeaf()
+
+ Signature\n
+ -----
+"""
+isLeaf(node::MCTSNode)::Bool = isempty(node.children)
+
+# ------------------------------------------------------------------------------------------------ #
+# Create a complete example using the defined MCTS functions #
+# ------------------------------------------------------------------------------------------------ #
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+
+ Signature\n
+ -----
+"""
+function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
+ root = MCTSNode(initial_state, 0, 0.0, Dict())
+
+ for _ in 1:max_iterations
+ node = root
+ while !isLeaf(node)
+ node = select(node, w)
+ end
+
+ expand(node, node.state, actions)
+
+ leaf_node = node.children[node.state]
+ reward = simulate(leaf_node.state, max_depth)
+ backpropagate(leaf_node, reward)
+ end
+
+ best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
+ return best_child_state
+end
+
+# Define your transition function and action selection function here
+
+# Example usage
+initial_state = 0
+actions = [-1, 0, 1]
+best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
+println("Best action to take: ", best_action)
+
+
+
+
+
+
+
+
+
+end
\ No newline at end of file
diff --git a/src/mcts.jl b/src/mcts.jl
index 76e2674..eb882fe 100644
--- a/src/mcts.jl
+++ b/src/mcts.jl
@@ -3,9 +3,9 @@
and functions for the MCTS algorithm:
"""
-module MCTS
+module mcts
-# export
+export runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
@@ -37,7 +37,7 @@ using GeneralUtils
struct MCTSNode{T}
state::T
visits::Int
- total_reward::Float64
+ stateValue::Float64
children::Dict{T, MCTSNode}
end
@@ -45,7 +45,10 @@ end
Arguments\n
-----
-
+ node::MCTSNode
+ mcts node
+ w::Float64
+ exploration weight
Return\n
-----
@@ -58,25 +61,70 @@ end
TODO\n
-----
[] update docstring
- [] implement the function
+ [DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
-function select(node::MCTSNode, c::Float64)
+function select(node::MCTSNode, w::Float64)
max_uct = -Inf
- selected_node = nothing
+ selectedNode = nothing
- for (child_state, child_node) in node.children
- uct_value = child_node.total_reward / child_node.visits +
- c * sqrt(log(node.visits) / child_node.visits)
- if uct_value > max_uct
- max_uct = uct_value
- selected_node = child_node
+ for (childState, childNode) in node.children
+ uctValue = childNode.stateValue +
+ w * sqrt(log(node.visits) / childNode.visits)
+ if uctValue > max_uct
+ max_uct = uctValue
+ selectedNode = childNode
end
end
- return selected_node
+ return selectedNode
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [WORKING] implement the function
+
+ Signature\n
+ -----
+"""
+function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
+ n::Integer=3) where {T<:Any}
+
+ actions = []
+
+ # sampling action from decisionMaker
+ # for nth in 1:n
+
+
+ # end
+
+
+
+
+
+ for action in actions
+ newState = transition(node.state, action) # Implement your transition function
+ if newState ∉ keys(node.children)
+ node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
+ end
+ end
end
"""
@@ -101,38 +149,7 @@ end
Signature\n
-----
"""
-function expand(node::MCTSNode, state::T, actions::Vector{T})
- for action in actions
- new_state = transition(node.state, action) # Implement your transition function
- if new_state ∉ keys(node.children)
- node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
- end
- end
-end
-
-"""
-
- Arguments\n
- -----
-
- Return\n
- -----
-
- Example\n
- -----
- ```jldoctest
- julia>
- ```
-
- TODO\n
- -----
- [] update docstring
- [] implement the function
-
- Signature\n
- -----
-"""
-function simulate(state::T, max_depth::Int)
+function simulate(state::T, max_depth::Int) where {T<:Any}
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
@@ -224,9 +241,6 @@ end
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
-# ------------------------------------------------------------------------------------------------ #
-# Create a complete example using the defined MCTS functions #
-# ------------------------------------------------------------------------------------------------ #
"""
Arguments\n
@@ -244,37 +258,128 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
TODO\n
-----
[] update docstring
+ [] implement the function
+ [] implement RAG to pull similar experience
Signature\n
-----
"""
-function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
- root = MCTSNode(initial_state, 0, 0.0, Dict())
+function decisionMaker()
- for _ in 1:max_iterations
- node = root
- while !isLeaf(node)
- node = select(node, w)
- end
+end
- expand(node, node.state, actions)
+"""
- leaf_node = node.children[node.state]
- reward = simulate(leaf_node.state, max_depth)
- backpropagate(leaf_node, reward)
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+function stateValueEstimator()
+
+end
+
+"""
+
+ Arguments\n
+ -----
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+ [] implement the function
+
+ Signature\n
+ -----
+"""
+function reflector()
+
+end
+
+# ------------------------------------------------------------------------------------------------ #
+# Create a complete example using the defined MCTS functions #
+# ------------------------------------------------------------------------------------------------ #
+""" Search for best action
+
+ Arguments\n
+ -----
+ initial state
+ initial state
+ decisionMaker::Function
+ decide what action to take
+ stateValueEstimator::Function
+ assess the value of the state
+ reflector::Function
+ generate lesson from trajectory and reward
+ n::Integer
+ how many times action will be sampled from decisionMaker
+ w::Float64
+ exploration weight
+
+ Return\n
+ -----
+
+ Example\n
+ -----
+ ```jldoctest
+ julia>
+ ```
+
+ TODO\n
+ -----
+ [] update docstring
+
+ Signature\n
+ -----
+"""
+function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
+ reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
+ maxIterations::Integer, w::Float64)
+ root = MCTSNode(initialState, 0, 0.0, Dict())
+
+ for _ in 1:maxIterations
+ node = root
+ while !isLeaf(node)
+ node = select(node, w)
+ end
+
+ expand(node, node.state, decisionMaker, stateValueEstimator,
+ n=n)
+
+ leaf_node = node.children[node.state]
+ reward = simulate(leaf_node.state, maxDepth)
+ backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end
-# Define your transition function and action selection function here
-# Example usage
-initial_state = 0
-actions = [-1, 0, 1]
-best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
-println("Best action to take: ", best_action)
diff --git a/test/runtest.jl b/test/runtest.jl
new file mode 100644
index 0000000..7faa257
--- /dev/null
+++ b/test/runtest.jl
@@ -0,0 +1,109 @@
+using Revise # remove when this package is completed
+using YiemAgent, GeneralUtils, JSON3, MQTTClient, Dates, UUIDs
+using Base.Threads
+
+# ---------------------------------------------- 100 --------------------------------------------- #
+
+config = copy(JSON3.read("config.json"))
+
+instanceInternalTopic = config[:serviceInternalTopic][:value] * "/1"
+
+client, connection = MakeConnection(config[:mqttServerInfo][:value][:broker],
+ config[:mqttServerInfo][:value][:port])
+
+receiveUserMsgChannel = Channel{Dict}(4)
+receiveInternalMsgChannel = Channel{Dict}(4)
+
+msgMeta = GeneralUtils.generate_msgMeta(
+ "N/A",
+ replyTopic = config[:servicetopic][:value] # ask frontend reply to this instance_chat_topic
+ )
+
+agentConfig = Dict(
+ :receiveprompt=>Dict(
+ :mqtttopic=> config[:servicetopic][:value], # topic to receive prompt i.e. frontend send msg to this topic
+ ),
+ :receiveinternal=>Dict(
+ :mqtttopic=> instanceInternalTopic, # receive topic for model's internal
+ ),
+ :text2text=>Dict(
+ :mqtttopic=> config[:text2text][:value],
+ ),
+ )
+
+# Instantiate an agent
+tools=Dict( # update input format
+ "askbox"=> Dict(
+ :description => "Useful for when you need to ask the user for more context. Do not ask the user their own question.",
+ :input => """Input is a text in JSON format.{\"Q1\": \"How are you doing?\", \"Q2\": \"How may I help you?\"}""",
+ :output => "" ,
+ :func => nothing,
+ ),
+ # "winestock"=> Dict(
+ # :description => "A handy tool for searching wine in your inventory that match the user preferences.",
+ # :input => """Input is a JSON-formatted string that contains a detailed and precise search query.{\"wine type\": \"rose\", \"price\": \"max 35\", \"sweetness level\": \"sweet\", \"intensity level\": \"light bodied\", \"Tannin level\": \"low\", \"Acidity level\": \"low\"}""",
+ # :output => """