This commit is contained in:
narawat lamaiin
2024-04-21 16:19:32 +07:00
parent b8d036e800
commit ee1446b1e2
4 changed files with 651 additions and 147 deletions

View File

@@ -4,7 +4,7 @@ export addNewMessage, conversation
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
using GeneralUtils
using ..type, ..util, ..llmfunction
using ..type, ..util, ..llmfunction, ..mcts
# ------------------------------------------------------------------------------------------------ #
# pythoncall setting #
@@ -85,102 +85,105 @@ using ..type, ..util, ..llmfunction
Signature\n
-----
"""
function conversation(a::T) where {T<:agent}
function conversation(a::T, userinput::Dict) where {T<:agent}
"""
[] update document
[] MCTS() for planning
[x] MCTS() for planning
"""
while true
# check for incoming user message
if isready(a.receiveUserMsgChannel)
incomingMsg = take!(a.receiveUserMsgChannel)
incomingPayload = incomingMsg[:payload]
# "newtopic" command to delete chat history
if userinput[:text] == "newtopic"
clearhistory(a)
# "newtopic" command to delete chat history
if incomingPayload[:text] == "newtopic"
clearhistory(a)
msgMeta = deepcopy(a.msgMeta)
msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
msgMeta[:senderName] = "agent-backend"
msgMeta[:senderId] = a.id
msgMeta[:receiverName] = "agent-frontend"
msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
msgMeta[:msgId] = string(uuid4())
msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
return "Okay. What shall we talk about?"
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:name=> a.name, # will be shown in frontend as agent name
:text => "Okay. What shall we talk about?",
)
)
_ = GeneralUtils.sendMqttMsg(outgoingMsg)
else
# add usermsg to a.chathistory
addNewMessage(a, "user", userinput[:text])
else # a new thinking
# add usermsg to a.chathistory
addNewMessage(a, "user", usermsg)
#[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned
if !isempty(a.plan[:currenttrajectory]) &&
a.plan[:currenttrajectory][end][:action] == "chatbox"
#[WORKING] if the last used tool is a chatbox
if a.plan[:currenttrajectory][end][:action] == "chatbox"
#usermsg -> observation and continue actor loop as planned
else #[WORKING] new thinking
else
#planning with MCTS() -> best plan
#actor loop(best plan)
initialState = 0
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0)
error("---> bestplan")
# actor loop(bestplan)
end
end
end
sleep(1)
end
end
# workstate = nothing
# response = nothing
# _ = addNewMessage(a, "user", usermsg)
# isuseplan = isUsePlans(a)
# # newinfo = extractinfo(a, usermsg)
# # a.env = newinfo !== nothing ? updateEnvState(a, newinfo) : a.env
# @show isuseplan
# if isuseplan # use plan before responding
# if haskey(a.memory[:shortterm], "User:") == false #[] should change role if user want to buy wine.
# a.memory[:shortterm]["User:"] = usermsg
# end
# workstate, response = work(a)
# end
# # if LLM using askbox, use returning msg form askbox as conversation response
# if workstate == "askbox" || workstate == "formulatedUserResponse"
# #[] paraphrase msg so that it is human friendlier word.
# else
# response = chat_mistral_openorca(a)
# response = split(response, "\n\n")[1]
# response = split(response, "\n\n")[1]
# end
# response = removeTrailingCharacters(response)
# _ = addNewMessage(a, "assistant", response)
end
# function conversation(a::T) where {T<:agent}
# """
# [] update document
# [x] MCTS() for planning
# """
# while true
# # check for incoming user message
# if isready(a.receiveUserMsgChannel)
# incomingMsg = take!(a.receiveUserMsgChannel)
# incomingPayload = incomingMsg[:payload]
# @show incomingMsg
# # "newtopic" command to delete chat history
# if incomingPayload[:text] == "newtopic"
# clearhistory(a)
# msgMeta = deepcopy(a.msgMeta)
# msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
# msgMeta[:senderName] = "agent-backend"
# msgMeta[:senderId] = a.id
# msgMeta[:receiverName] = "agent-frontend"
# msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
# msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
# msgMeta[:msgId] = string(uuid4())
# msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :name=> a.name, # will be shown in frontend as agent name
# :text => "Okay. What shall we talk about?",
# )
# )
# # _ = GeneralUtils.sendMqttMsg(outgoingMsg)
# else
# @show a = 55555
# # add usermsg to a.chathistory
# addNewMessage(a, "user", usermsg)
# #[] if the last used tool is a chatbox
# if a.plan[:currenttrajectory][end][:action] == "chatbox"
# #usermsg -> observation and continue actor loop as planned
# else #[WORKING] new thinking
# initialState = 0
# bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
# 3, 10, 1000, 1.0)
# # actor loop(best plan)
# end
# end
# end
# sleep(1)
# end
# end

287
src/mcts copy 2.jl Normal file
View File

@@ -0,0 +1,287 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
"""
module MCTS
# export
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
struct MCTSNode{T}
state::T
visits::Int
total_reward::Float64
children::Dict{T, MCTSNode}
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
function select(node::MCTSNode, c::Float64)
max_uct = -Inf
selected_node = nothing
for (child_state, child_node) in node.children
uct_value = child_node.total_reward / child_node.visits +
c * sqrt(log(node.visits) / child_node.visits)
if uct_value > max_uct
max_uct = uct_value
selected_node = child_node
end
end
return selected_node
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, actions::Vector{T})
for action in actions
new_state = transition(node.state, action) # Implement your transition function
if new_state keys(node.children)
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function simulate(state::T, max_depth::Int)
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
state, reward = transition(state, action) # Implement your transition function
total_reward += reward
end
return total_reward
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1
node.total_reward += reward
if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward)
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function transition(state, action)
end
""" Check whether a node is a leaf node
Arguments\n
-----
Return\n
-----
a task represent an agent
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[DONE] implement isLeaf()
Signature\n
-----
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
root = MCTSNode(initial_state, 0, 0.0, Dict())
for _ in 1:max_iterations
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(node, node.state, actions)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, max_depth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end
# Define your transition function and action selection function here
# Example usage
initial_state = 0
actions = [-1, 0, 1]
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
println("Best action to take: ", best_action)
end

View File

@@ -3,9 +3,9 @@
and functions for the MCTS algorithm:
"""
module MCTS
module mcts
# export
export runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
@@ -37,7 +37,7 @@ using GeneralUtils
struct MCTSNode{T}
state::T
visits::Int
total_reward::Float64
stateValue::Float64
children::Dict{T, MCTSNode}
end
@@ -45,7 +45,10 @@ end
Arguments\n
-----
node::MCTSNode
mcts node
w::Float64
exploration weight
Return\n
-----
@@ -58,25 +61,70 @@ end
TODO\n
-----
[] update docstring
[] implement the function
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
function select(node::MCTSNode, c::Float64)
function select(node::MCTSNode, w::Float64)
max_uct = -Inf
selected_node = nothing
selectedNode = nothing
for (child_state, child_node) in node.children
uct_value = child_node.total_reward / child_node.visits +
c * sqrt(log(node.visits) / child_node.visits)
if uct_value > max_uct
max_uct = uct_value
selected_node = child_node
for (childState, childNode) in node.children
uctValue = childNode.stateValue +
w * sqrt(log(node.visits) / childNode.visits)
if uctValue > max_uct
max_uct = uctValue
selectedNode = childNode
end
end
return selected_node
return selectedNode
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T<:Any}
actions = []
# sampling action from decisionMaker
# for nth in 1:n
# end
for action in actions
newState = transition(node.state, action) # Implement your transition function
if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
@@ -101,38 +149,7 @@ end
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, actions::Vector{T})
for action in actions
new_state = transition(node.state, action) # Implement your transition function
if new_state keys(node.children)
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function simulate(state::T, max_depth::Int)
function simulate(state::T, max_depth::Int) where {T<:Any}
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
@@ -224,9 +241,6 @@ end
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
"""
Arguments\n
@@ -244,37 +258,128 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
TODO\n
-----
[] update docstring
[] implement the function
[] implement RAG to pull similar experience
Signature\n
-----
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
root = MCTSNode(initial_state, 0, 0.0, Dict())
function decisionMaker()
for _ in 1:max_iterations
node = root
while !isLeaf(node)
node = select(node, w)
end
end
expand(node, node.state, actions)
"""
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, max_depth)
backpropagate(leaf_node, reward)
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" Search for best action
Arguments\n
-----
initial state
initial state
decisionMaker::Function
decide what action to take
stateValueEstimator::Function
assess the value of the state
reflector::Function
generate lesson from trajectory and reward
n::Integer
how many times action will be sampled from decisionMaker
w::Float64
exploration weight
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64)
root = MCTSNode(initialState, 0, 0.0, Dict())
for _ in 1:maxIterations
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(node, node.state, decisionMaker, stateValueEstimator,
n=n)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end
# Define your transition function and action selection function here
# Example usage
initial_state = 0
actions = [-1, 0, 1]
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
println("Best action to take: ", best_action)