diff --git a/src/interface.jl b/src/interface.jl index fb5b4c9..d5ff814 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -4,7 +4,7 @@ export addNewMessage, conversation using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient using GeneralUtils -using ..type, ..util, ..llmfunction +using ..type, ..util, ..llmfunction, ..mcts # ------------------------------------------------------------------------------------------------ # # pythoncall setting # @@ -85,102 +85,105 @@ using ..type, ..util, ..llmfunction Signature\n ----- """ -function conversation(a::T) where {T<:agent} +function conversation(a::T, userinput::Dict) where {T<:agent} """ [] update document - [] MCTS() for planning + [x] MCTS() for planning """ - while true - # check for incoming user message - if isready(a.receiveUserMsgChannel) - incomingMsg = take!(a.receiveUserMsgChannel) - incomingPayload = incomingMsg[:payload] + # "newtopic" command to delete chat history + if userinput[:text] == "newtopic" + clearhistory(a) - # "newtopic" command to delete chat history - if incomingPayload[:text] == "newtopic" - clearhistory(a) - msgMeta = deepcopy(a.msgMeta) - msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic] - msgMeta[:senderName] = "agent-backend" - msgMeta[:senderId] = a.id - msgMeta[:receiverName] = "agent-frontend" - msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId] - msgMeta[:replyTopic] = a.config[:receivemsg][:prompt] - msgMeta[:msgId] = string(uuid4()) - msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId] + return "Okay. What shall we talk about?" - outgoingMsg = Dict( - :msgMeta=> msgMeta, - :payload=> Dict( - :name=> a.name, # will be shown in frontend as agent name - :text => "Okay. What shall we talk about?", - ) - ) - _ = GeneralUtils.sendMqttMsg(outgoingMsg) + else + # add usermsg to a.chathistory + addNewMessage(a, "user", userinput[:text]) - else # a new thinking - # add usermsg to a.chathistory - addNewMessage(a, "user", usermsg) + #[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned + if !isempty(a.plan[:currenttrajectory]) && + a.plan[:currenttrajectory][end][:action] == "chatbox" - #[WORKING] if the last used tool is a chatbox - if a.plan[:currenttrajectory][end][:action] == "chatbox" - #usermsg -> observation and continue actor loop as planned + + + else #[WORKING] new thinking - else - #planning with MCTS() -> best plan - #actor loop(best plan) + initialState = 0 + bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, + 3, 10, 1000, 1.0) + error("---> bestplan") + # actor loop(bestplan) - end - end - end - sleep(1) + end end - - - - - - - - - - - - # workstate = nothing - # response = nothing - - # _ = addNewMessage(a, "user", usermsg) - # isuseplan = isUsePlans(a) - # # newinfo = extractinfo(a, usermsg) - # # a.env = newinfo !== nothing ? updateEnvState(a, newinfo) : a.env - # @show isuseplan - - # if isuseplan # use plan before responding - # if haskey(a.memory[:shortterm], "User:") == false #[] should change role if user want to buy wine. - # a.memory[:shortterm]["User:"] = usermsg - # end - # workstate, response = work(a) - # end - - # # if LLM using askbox, use returning msg form askbox as conversation response - # if workstate == "askbox" || workstate == "formulatedUserResponse" - # #[] paraphrase msg so that it is human friendlier word. - # else - # response = chat_mistral_openorca(a) - # response = split(response, "\n\n")[1] - # response = split(response, "\n\n")[1] - # end - - # response = removeTrailingCharacters(response) - # _ = addNewMessage(a, "assistant", response) - - end +# function conversation(a::T) where {T<:agent} +# """ +# [] update document +# [x] MCTS() for planning +# """ +# while true +# # check for incoming user message +# if isready(a.receiveUserMsgChannel) +# incomingMsg = take!(a.receiveUserMsgChannel) +# incomingPayload = incomingMsg[:payload] +# @show incomingMsg + +# # "newtopic" command to delete chat history +# if incomingPayload[:text] == "newtopic" +# clearhistory(a) +# msgMeta = deepcopy(a.msgMeta) +# msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic] +# msgMeta[:senderName] = "agent-backend" +# msgMeta[:senderId] = a.id +# msgMeta[:receiverName] = "agent-frontend" +# msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId] +# msgMeta[:replyTopic] = a.config[:receivemsg][:prompt] +# msgMeta[:msgId] = string(uuid4()) +# msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId] + +# outgoingMsg = Dict( +# :msgMeta=> msgMeta, +# :payload=> Dict( +# :name=> a.name, # will be shown in frontend as agent name +# :text => "Okay. What shall we talk about?", +# ) +# ) +# # _ = GeneralUtils.sendMqttMsg(outgoingMsg) + +# else +# @show a = 55555 +# # add usermsg to a.chathistory +# addNewMessage(a, "user", usermsg) + +# #[] if the last used tool is a chatbox +# if a.plan[:currenttrajectory][end][:action] == "chatbox" +# #usermsg -> observation and continue actor loop as planned + + + +# else #[WORKING] new thinking + + +# initialState = 0 +# bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, +# 3, 10, 1000, 1.0) + +# # actor loop(best plan) + +# end +# end +# end +# sleep(1) +# end +# end + + diff --git a/src/mcts copy 2.jl b/src/mcts copy 2.jl new file mode 100644 index 0000000..bc5b513 --- /dev/null +++ b/src/mcts copy 2.jl @@ -0,0 +1,287 @@ +""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence + Bound for Trees) selection function, you can follow the steps below: Define the necessary types + and functions for the MCTS algorithm: +""" + +module MCTS + +# export + +using Dates, UUIDs, DataStructures, JSON3, Random +using GeneralUtils + +# ---------------------------------------------- 100 --------------------------------------------- # + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +struct MCTSNode{T} + state::T + visits::Int + total_reward::Float64 + children::Dict{T, MCTSNode} +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing + + Signature\n + ----- +""" +function select(node::MCTSNode, c::Float64) + max_uct = -Inf + selected_node = nothing + + for (child_state, child_node) in node.children + uct_value = child_node.total_reward / child_node.visits + + c * sqrt(log(node.visits) / child_node.visits) + if uct_value > max_uct + max_uct = uct_value + selected_node = child_node + end + end + + return selected_node +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +function expand(node::MCTSNode, state::T, actions::Vector{T}) + for action in actions + new_state = transition(node.state, action) # Implement your transition function + if new_state ∉ keys(node.children) + node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}()) + end + end +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +function simulate(state::T, max_depth::Int) + total_reward = 0.0 + for _ in 1:max_depth + action = select_action(state) # Implement your action selection function + state, reward = transition(state, action) # Implement your transition function + total_reward += reward + end + return total_reward +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +function backpropagate(node::MCTSNode, reward::Float64) + node.visits += 1 + node.total_reward += reward + if !isempty(node.children) + best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) + backpropagate(node.children[best_child], -reward) + end +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +function transition(state, action) + +end + +""" Check whether a node is a leaf node + + Arguments\n + ----- + + Return\n + ----- + a task represent an agent + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [DONE] implement isLeaf() + + Signature\n + ----- +""" +isLeaf(node::MCTSNode)::Bool = isempty(node.children) + +# ------------------------------------------------------------------------------------------------ # +# Create a complete example using the defined MCTS functions # +# ------------------------------------------------------------------------------------------------ # +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + + Signature\n + ----- +""" +function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64) + root = MCTSNode(initial_state, 0, 0.0, Dict()) + + for _ in 1:max_iterations + node = root + while !isLeaf(node) + node = select(node, w) + end + + expand(node, node.state, actions) + + leaf_node = node.children[node.state] + reward = simulate(leaf_node.state, max_depth) + backpropagate(leaf_node, reward) + end + + best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) + return best_child_state +end + +# Define your transition function and action selection function here + +# Example usage +initial_state = 0 +actions = [-1, 0, 1] +best_action = run_mcts(initial_state, actions, 1000, 10, 1.0) +println("Best action to take: ", best_action) + + + + + + + + + +end \ No newline at end of file diff --git a/src/mcts.jl b/src/mcts.jl index 76e2674..eb882fe 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -3,9 +3,9 @@ and functions for the MCTS algorithm: """ -module MCTS +module mcts -# export +export runMCTS using Dates, UUIDs, DataStructures, JSON3, Random using GeneralUtils @@ -37,7 +37,7 @@ using GeneralUtils struct MCTSNode{T} state::T visits::Int - total_reward::Float64 + stateValue::Float64 children::Dict{T, MCTSNode} end @@ -45,7 +45,10 @@ end Arguments\n ----- - + node::MCTSNode + mcts node + w::Float64 + exploration weight Return\n ----- @@ -58,25 +61,70 @@ end TODO\n ----- [] update docstring - [] implement the function + [DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing Signature\n ----- """ -function select(node::MCTSNode, c::Float64) +function select(node::MCTSNode, w::Float64) max_uct = -Inf - selected_node = nothing + selectedNode = nothing - for (child_state, child_node) in node.children - uct_value = child_node.total_reward / child_node.visits + - c * sqrt(log(node.visits) / child_node.visits) - if uct_value > max_uct - max_uct = uct_value - selected_node = child_node + for (childState, childNode) in node.children + uctValue = childNode.stateValue + + w * sqrt(log(node.visits) / childNode.visits) + if uctValue > max_uct + max_uct = uctValue + selectedNode = childNode end end - return selected_node + return selectedNode +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [WORKING] implement the function + + Signature\n + ----- +""" +function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function; + n::Integer=3) where {T<:Any} + + actions = [] + + # sampling action from decisionMaker + # for nth in 1:n + + + # end + + + + + + for action in actions + newState = transition(node.state, action) # Implement your transition function + if newState ∉ keys(node.children) + node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}()) + end + end end """ @@ -101,38 +149,7 @@ end Signature\n ----- """ -function expand(node::MCTSNode, state::T, actions::Vector{T}) - for action in actions - new_state = transition(node.state, action) # Implement your transition function - if new_state ∉ keys(node.children) - node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}()) - end - end -end - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [] implement the function - - Signature\n - ----- -""" -function simulate(state::T, max_depth::Int) +function simulate(state::T, max_depth::Int) where {T<:Any} total_reward = 0.0 for _ in 1:max_depth action = select_action(state) # Implement your action selection function @@ -224,9 +241,6 @@ end """ isLeaf(node::MCTSNode)::Bool = isempty(node.children) -# ------------------------------------------------------------------------------------------------ # -# Create a complete example using the defined MCTS functions # -# ------------------------------------------------------------------------------------------------ # """ Arguments\n @@ -244,37 +258,128 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children) TODO\n ----- [] update docstring + [] implement the function + [] implement RAG to pull similar experience Signature\n ----- """ -function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64) - root = MCTSNode(initial_state, 0, 0.0, Dict()) +function decisionMaker() - for _ in 1:max_iterations - node = root - while !isLeaf(node) - node = select(node, w) - end +end - expand(node, node.state, actions) +""" - leaf_node = node.children[node.state] - reward = simulate(leaf_node.state, max_depth) - backpropagate(leaf_node, reward) + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +function stateValueEstimator() + +end + +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + + Signature\n + ----- +""" +function reflector() + +end + +# ------------------------------------------------------------------------------------------------ # +# Create a complete example using the defined MCTS functions # +# ------------------------------------------------------------------------------------------------ # +""" Search for best action + + Arguments\n + ----- + initial state + initial state + decisionMaker::Function + decide what action to take + stateValueEstimator::Function + assess the value of the state + reflector::Function + generate lesson from trajectory and reward + n::Integer + how many times action will be sampled from decisionMaker + w::Float64 + exploration weight + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + + Signature\n + ----- +""" +function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function, + reflector::Function, totalActionSampled::Integer, maxDepth::Integer, + maxIterations::Integer, w::Float64) + root = MCTSNode(initialState, 0, 0.0, Dict()) + + for _ in 1:maxIterations + node = root + while !isLeaf(node) + node = select(node, w) + end + + expand(node, node.state, decisionMaker, stateValueEstimator, + n=n) + + leaf_node = node.children[node.state] + reward = simulate(leaf_node.state, maxDepth) + backpropagate(leaf_node, reward) end best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) return best_child_state end -# Define your transition function and action selection function here -# Example usage -initial_state = 0 -actions = [-1, 0, 1] -best_action = run_mcts(initial_state, actions, 1000, 10, 1.0) -println("Best action to take: ", best_action) diff --git a/test/runtest.jl b/test/runtest.jl new file mode 100644 index 0000000..7faa257 --- /dev/null +++ b/test/runtest.jl @@ -0,0 +1,109 @@ +using Revise # remove when this package is completed +using YiemAgent, GeneralUtils, JSON3, MQTTClient, Dates, UUIDs +using Base.Threads + +# ---------------------------------------------- 100 --------------------------------------------- # + +config = copy(JSON3.read("config.json")) + +instanceInternalTopic = config[:serviceInternalTopic][:value] * "/1" + +client, connection = MakeConnection(config[:mqttServerInfo][:value][:broker], + config[:mqttServerInfo][:value][:port]) + +receiveUserMsgChannel = Channel{Dict}(4) +receiveInternalMsgChannel = Channel{Dict}(4) + +msgMeta = GeneralUtils.generate_msgMeta( + "N/A", + replyTopic = config[:servicetopic][:value] # ask frontend reply to this instance_chat_topic + ) + +agentConfig = Dict( + :receiveprompt=>Dict( + :mqtttopic=> config[:servicetopic][:value], # topic to receive prompt i.e. frontend send msg to this topic + ), + :receiveinternal=>Dict( + :mqtttopic=> instanceInternalTopic, # receive topic for model's internal + ), + :text2text=>Dict( + :mqtttopic=> config[:text2text][:value], + ), + ) + +# Instantiate an agent +tools=Dict( # update input format + "askbox"=> Dict( + :description => "Useful for when you need to ask the user for more context. Do not ask the user their own question.", + :input => """Input is a text in JSON format.{\"Q1\": \"How are you doing?\", \"Q2\": \"How may I help you?\"}""", + :output => "" , + :func => nothing, + ), + # "winestock"=> Dict( + # :description => "A handy tool for searching wine in your inventory that match the user preferences.", + # :input => """Input is a JSON-formatted string that contains a detailed and precise search query.{\"wine type\": \"rose\", \"price\": \"max 35\", \"sweetness level\": \"sweet\", \"intensity level\": \"light bodied\", \"Tannin level\": \"low\", \"Acidity level\": \"low\"}""", + # :output => """Output are wines that match the search query in JSON format.""", + # :func => ChatAgent.winestock, + # ), + "finalanswer"=> Dict( + :description => "Useful for when you are ready to recommend wines to the user.", + :input => """{\"finalanswer\": \"some text\"}.{\"finalanswer\": \"I recommend Zena Crown Vista\"}""", + :output => "" , + :func => nothing, + ), + ) + + a = YiemAgent.sommelier( + receiveUserMsgChannel, + receiveInternalMsgChannel, + msgMeta, + agentConfig, + name= "assistant", + id= "randomSessionID", # agent instance id + tools=tools, + ) + +response = YiemAgent.conversation(a, Dict(:text=> "newtopic", ) ) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +