update
This commit is contained in:
149
src/interface.jl
149
src/interface.jl
@@ -4,7 +4,7 @@ export addNewMessage, conversation
|
|||||||
|
|
||||||
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
|
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
using ..type, ..util, ..llmfunction
|
using ..type, ..util, ..llmfunction, ..mcts
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
# pythoncall setting #
|
# pythoncall setting #
|
||||||
@@ -85,100 +85,103 @@ using ..type, ..util, ..llmfunction
|
|||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function conversation(a::T) where {T<:agent}
|
function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||||
"""
|
"""
|
||||||
[] update document
|
[] update document
|
||||||
[] MCTS() for planning
|
[x] MCTS() for planning
|
||||||
"""
|
"""
|
||||||
while true
|
|
||||||
# check for incoming user message
|
|
||||||
if isready(a.receiveUserMsgChannel)
|
|
||||||
incomingMsg = take!(a.receiveUserMsgChannel)
|
|
||||||
incomingPayload = incomingMsg[:payload]
|
|
||||||
|
|
||||||
# "newtopic" command to delete chat history
|
# "newtopic" command to delete chat history
|
||||||
if incomingPayload[:text] == "newtopic"
|
if userinput[:text] == "newtopic"
|
||||||
clearhistory(a)
|
clearhistory(a)
|
||||||
msgMeta = deepcopy(a.msgMeta)
|
|
||||||
msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
|
|
||||||
msgMeta[:senderName] = "agent-backend"
|
|
||||||
msgMeta[:senderId] = a.id
|
|
||||||
msgMeta[:receiverName] = "agent-frontend"
|
|
||||||
msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
|
|
||||||
msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
|
|
||||||
msgMeta[:msgId] = string(uuid4())
|
|
||||||
msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
|
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
|
||||||
:msgMeta=> msgMeta,
|
|
||||||
:payload=> Dict(
|
|
||||||
:name=> a.name, # will be shown in frontend as agent name
|
|
||||||
:text => "Okay. What shall we talk about?",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_ = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|
||||||
|
|
||||||
else # a new thinking
|
|
||||||
# add usermsg to a.chathistory
|
|
||||||
addNewMessage(a, "user", usermsg)
|
|
||||||
|
|
||||||
#[WORKING] if the last used tool is a chatbox
|
|
||||||
if a.plan[:currenttrajectory][end][:action] == "chatbox"
|
|
||||||
#usermsg -> observation and continue actor loop as planned
|
|
||||||
|
|
||||||
|
return "Okay. What shall we talk about?"
|
||||||
|
|
||||||
else
|
else
|
||||||
#planning with MCTS() -> best plan
|
# add usermsg to a.chathistory
|
||||||
|
addNewMessage(a, "user", userinput[:text])
|
||||||
|
|
||||||
|
#[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned
|
||||||
|
if !isempty(a.plan[:currenttrajectory]) &&
|
||||||
|
a.plan[:currenttrajectory][end][:action] == "chatbox"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
else #[WORKING] new thinking
|
||||||
|
|
||||||
|
|
||||||
|
initialState = 0
|
||||||
|
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
||||||
|
3, 10, 1000, 1.0)
|
||||||
|
error("---> bestplan")
|
||||||
# actor loop(bestplan)
|
# actor loop(bestplan)
|
||||||
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
sleep(1)
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# function conversation(a::T) where {T<:agent}
|
||||||
|
# """
|
||||||
|
# [] update document
|
||||||
|
# [x] MCTS() for planning
|
||||||
|
# """
|
||||||
|
# while true
|
||||||
|
# # check for incoming user message
|
||||||
|
# if isready(a.receiveUserMsgChannel)
|
||||||
|
# incomingMsg = take!(a.receiveUserMsgChannel)
|
||||||
|
# incomingPayload = incomingMsg[:payload]
|
||||||
|
# @show incomingMsg
|
||||||
|
|
||||||
|
# # "newtopic" command to delete chat history
|
||||||
|
# if incomingPayload[:text] == "newtopic"
|
||||||
|
# clearhistory(a)
|
||||||
|
# msgMeta = deepcopy(a.msgMeta)
|
||||||
|
# msgMeta[:sendTopic] = incomingMsg[:msgMeta][:replyTopic]
|
||||||
|
# msgMeta[:senderName] = "agent-backend"
|
||||||
|
# msgMeta[:senderId] = a.id
|
||||||
|
# msgMeta[:receiverName] = "agent-frontend"
|
||||||
|
# msgMeta[:receiverId] = incomingMsg[:msgMeta][:senderId]
|
||||||
|
# msgMeta[:replyTopic] = a.config[:receivemsg][:prompt]
|
||||||
|
# msgMeta[:msgId] = string(uuid4())
|
||||||
|
# msgMeta[:replyToMsgId] = incomingMsg[:msgMeta][:msgId]
|
||||||
|
|
||||||
|
# outgoingMsg = Dict(
|
||||||
|
# :msgMeta=> msgMeta,
|
||||||
|
# :payload=> Dict(
|
||||||
|
# :name=> a.name, # will be shown in frontend as agent name
|
||||||
|
# :text => "Okay. What shall we talk about?",
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# # _ = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# workstate = nothing
|
|
||||||
# response = nothing
|
|
||||||
|
|
||||||
# _ = addNewMessage(a, "user", usermsg)
|
|
||||||
# isuseplan = isUsePlans(a)
|
|
||||||
# # newinfo = extractinfo(a, usermsg)
|
|
||||||
# # a.env = newinfo !== nothing ? updateEnvState(a, newinfo) : a.env
|
|
||||||
# @show isuseplan
|
|
||||||
|
|
||||||
# if isuseplan # use plan before responding
|
|
||||||
# if haskey(a.memory[:shortterm], "User:") == false #[] should change role if user want to buy wine.
|
|
||||||
# a.memory[:shortterm]["User:"] = usermsg
|
|
||||||
# end
|
|
||||||
# workstate, response = work(a)
|
|
||||||
# end
|
|
||||||
|
|
||||||
# # if LLM using askbox, use returning msg form askbox as conversation response
|
|
||||||
# if workstate == "askbox" || workstate == "formulatedUserResponse"
|
|
||||||
# #[] paraphrase msg so that it is human friendlier word.
|
|
||||||
# else
|
# else
|
||||||
# response = chat_mistral_openorca(a)
|
# @show a = 55555
|
||||||
# response = split(response, "\n\n")[1]
|
# # add usermsg to a.chathistory
|
||||||
# response = split(response, "\n\n")[1]
|
# addNewMessage(a, "user", usermsg)
|
||||||
|
|
||||||
|
# #[] if the last used tool is a chatbox
|
||||||
|
# if a.plan[:currenttrajectory][end][:action] == "chatbox"
|
||||||
|
# #usermsg -> observation and continue actor loop as planned
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# else #[WORKING] new thinking
|
||||||
|
|
||||||
|
|
||||||
|
# initialState = 0
|
||||||
|
# bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
||||||
|
# 3, 10, 1000, 1.0)
|
||||||
|
|
||||||
|
# # actor loop(best plan)
|
||||||
|
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
# sleep(1)
|
||||||
|
# end
|
||||||
# end
|
# end
|
||||||
|
|
||||||
# response = removeTrailingCharacters(response)
|
|
||||||
# _ = addNewMessage(a, "assistant", response)
|
|
||||||
|
|
||||||
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
287
src/mcts copy 2.jl
Normal file
287
src/mcts copy 2.jl
Normal file
@@ -0,0 +1,287 @@
|
|||||||
""" Monte Carlo Tree Search (MCTS) in Julia using the UCT (Upper Confidence Bound for Trees)
    selection function. This module defines the types and functions needed by the MCTS
    algorithm.
"""
|
||||||
|
|
||||||
|
module MCTS
|
||||||
|
|
||||||
|
# export
|
||||||
|
|
||||||
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||||
|
using GeneralUtils
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
""" A node of the Monte Carlo search tree.

    The struct is mutable because `visits` and `total_reward` are updated in place while the
    tree is searched (see `backpropagate`).

    Arguments\n
    -----
        state::T
            state represented by this node
        visits::Int
            visit counter of this node
        total_reward::Float64
            reward accumulated by this node
        children::Dict{T, MCTSNode}
            child nodes, keyed by the state each child represents

    Return\n
    -----
        a MCTSNode{T}

    Example\n
    -----
    ```jldoctest
    julia> root = MCTSNode(0, 0, 0.0, Dict{Int, MCTSNode}());

    julia> root.visits
    0
    ```

    TODO\n
    -----
        [] update docstring

    Signature\n
    -----
"""
mutable struct MCTSNode{T}
    state::T
    visits::Int
    total_reward::Float64
    children::Dict{T, MCTSNode}
end
|
||||||
|
|
||||||
""" Pick the child of `node` to descend into, using the UCT rule.

    Arguments\n
    -----
        node::MCTSNode
            node whose children are compared
        c::Float64
            weight of the exploration term

    Return\n
    -----
        the first child encountered that has never been visited; otherwise the child with
        the highest UCT value; `nothing` if `node` has no children

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring
        [WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing

    Signature\n
    -----
"""
function select(node::MCTSNode, c::Float64)
    max_uct = -Inf
    selected_node = nothing

    for (_, child_node) in node.children
        # An unvisited child has no statistics yet: total_reward / visits would be 0/0 (NaN)
        # and the exploration term would divide by zero, so explore it right away.
        child_node.visits == 0 && return child_node

        exploitation = child_node.total_reward / child_node.visits
        # max(..., 1) keeps log() non-negative; log(0) == -Inf would make sqrt() throw.
        exploration = c * sqrt(log(max(node.visits, 1)) / child_node.visits)
        uct_value = exploitation + exploration

        if uct_value > max_uct
            max_uct = uct_value
            selected_node = child_node
        end
    end

    return selected_node
end
|
||||||
|
|
||||||
""" Add one child to `node` for every action that leads to a state `node` does not have yet.

    Arguments\n
    -----
        node::MCTSNode
            node to expand; new children are inserted into `node.children`
        state::T
            not read by the current implementation; it only fixes the key type `T` of the
            children dictionaries created here
        actions::Vector{T}
            actions to try from `node.state`

    Return\n
    -----
        nothing

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring
        [] implement the function

    Signature\n
    -----
"""
function expand(node::MCTSNode, state::T, actions::Vector{T}) where {T}
    for action in actions
        # The value returned by transition() is used directly as the child's state.
        new_state = transition(node.state, action) # Implement your transition function
        if !haskey(node.children, new_state)
            node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
        end
    end
    return nothing
end
|
||||||
|
|
||||||
""" Roll out from `state` for `max_depth` steps and sum the rewards collected on the way.

    Arguments\n
    -----
        state::T
            state the rollout starts from
        max_depth::Int
            number of steps to simulate; no step is taken when it is zero or negative

    Return\n
    -----
        the accumulated reward as a Float64 (0.0 when no step is taken)

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring
        [] implement the function

    Signature\n
    -----
"""
function simulate(state::T, max_depth::Int) where {T}
    total_reward = 0.0
    for _ in 1:max_depth
        action = select_action(state) # Implement your action selection function
        # transition() is destructured here, so it must return a (next_state, reward) pair.
        state, reward = transition(state, action) # Implement your transition function
        total_reward += reward
    end
    return total_reward
end
|
||||||
|
|
||||||
""" Record `reward` on `node`, then continue with the sign-flipped reward on its best child.

    Requires `MCTSNode` to be declared as a `mutable struct`; assigning to a field of an
    immutable struct throws.

    Arguments\n
    -----
        node::MCTSNode
            node whose `visits` and `total_reward` are updated
        reward::Float64
            reward added to `node.total_reward`; passed on as `-reward` to the chosen child

    Return\n
    -----
        nothing

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring
        [] implement the function

    Signature\n
    -----
"""
function backpropagate(node::MCTSNode, reward::Float64)
    node.visits += 1
    node.total_reward += reward
    if !isempty(node.children)
        # argmax() on a vector returns a position, not a dictionary key, so the children are
        # materialised once and the position is mapped back to the node itself.
        children = collect(values(node.children))
        best_index = argmax([child.total_reward / child.visits for child in children])
        backpropagate(children[best_index], -reward)
    end
    return nothing
end
|
||||||
|
|
||||||
""" Placeholder for the environment's transition model. Not implemented yet.

    Arguments\n
    -----
        state
            current state (ignored for now)
        action
            action to apply (ignored for now)

    Return\n
    -----
        nothing

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring
        [] implement the function

    Signature\n
    -----
"""
transition(state, action) = nothing # stub: no transition model yet
|
||||||
|
|
||||||
""" Check whether a node is a leaf node

    Arguments\n
    -----
        node::MCTSNode
            node to inspect

    Return\n
    -----
        `true` if `node` has no children, `false` otherwise

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring
        [DONE] implement isLeaf()

    Signature\n
    -----
"""
function isLeaf(node::MCTSNode)::Bool
    return length(node.children) == 0
end
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
# Create a complete example using the defined MCTS functions #
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
""" Run the search loop from `initial_state` and report the most promising child state.

    Each iteration descends from the root with `select` until a leaf is reached, expands that
    leaf with `expand`, rolls out from one of the new children with `simulate` and hands the
    reward to `backpropagate`.

    NOTE(review): `backpropagate` is only called on the rolled-out leaf, so the statistics of
    its ancestors (including the root) are not updated by this loop — confirm this is intended.

    Arguments\n
    -----
        initial_state
            state stored in the root node
        actions
            actions forwarded to `expand`
        max_iterations::Int
            number of search iterations
        max_depth::Int
            rollout depth forwarded to `simulate`
        w::Float64
            exploration weight forwarded to `select`

    Return\n
    -----
        the state (key in `root.children`) of the root child with the highest mean reward;
        children that were never visited rank last

    Example\n
    -----
    ```jldoctest
    julia>
    ```

    TODO\n
    -----
        [] update docstring

    Signature\n
    -----
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
    # The children dictionary must be typed to match the MCTSNode{T} constructor; an untyped
    # Dict() is a Dict{Any, Any} and does not.
    root = MCTSNode(initial_state, 0, 0.0, Dict{typeof(initial_state), MCTSNode}())

    for _ in 1:max_iterations
        node = root
        while !isLeaf(node)
            node = select(node, w)
        end

        expand(node, node.state, actions)

        # Children are keyed by their own state, not by the parent's, so the rollout starts
        # from one of the children just created. If expansion added nothing, the node itself
        # is used.
        leaf_node = isLeaf(node) ? node : first(values(node.children))
        reward = simulate(leaf_node.state, max_depth)
        backpropagate(leaf_node, reward)
    end

    isLeaf(root) && throw(ArgumentError("root has no children; no best child state to return"))

    # Keep states and nodes paired so the position returned by argmax() maps back to a state.
    candidates = collect(root.children)
    mean_rewards = [
        child.visits == 0 ? -Inf : child.total_reward / child.visits for
        (_, child) in candidates
    ]
    best_child_state = first(candidates[argmax(mean_rewards)])
    return best_child_state
end
|
||||||
|
|
||||||
|
# Define your transition function and action selection function here
|
||||||
|
|
||||||
|
# Example usage. Guarded so that it only runs when this file is executed directly as a script;
# otherwise merely loading the module would start a 1000-iteration search (and fail while
# transition() and select_action() are still unimplemented).
if abspath(PROGRAM_FILE) == @__FILE__
    initial_state = 0
    actions = [-1, 0, 1]
    best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
    println("Best action to take: ", best_action)
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end
|
||||||
173
src/mcts.jl
173
src/mcts.jl
@@ -3,9 +3,9 @@
|
|||||||
and functions for the MCTS algorithm:
|
and functions for the MCTS algorithm:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
module MCTS
|
module mcts
|
||||||
|
|
||||||
# export
|
export runMCTS
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3, Random
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -37,7 +37,7 @@ using GeneralUtils
|
|||||||
struct MCTSNode{T}
|
struct MCTSNode{T}
|
||||||
state::T
|
state::T
|
||||||
visits::Int
|
visits::Int
|
||||||
total_reward::Float64
|
stateValue::Float64
|
||||||
children::Dict{T, MCTSNode}
|
children::Dict{T, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -45,7 +45,10 @@ end
|
|||||||
|
|
||||||
Arguments\n
|
Arguments\n
|
||||||
-----
|
-----
|
||||||
|
node::MCTSNode
|
||||||
|
mcts node
|
||||||
|
w::Float64
|
||||||
|
exploration weight
|
||||||
Return\n
|
Return\n
|
||||||
-----
|
-----
|
||||||
|
|
||||||
@@ -58,25 +61,25 @@ end
|
|||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[] implement the function
|
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function select(node::MCTSNode, c::Float64)
|
function select(node::MCTSNode, w::Float64)
|
||||||
max_uct = -Inf
|
max_uct = -Inf
|
||||||
selected_node = nothing
|
selectedNode = nothing
|
||||||
|
|
||||||
for (child_state, child_node) in node.children
|
for (childState, childNode) in node.children
|
||||||
uct_value = child_node.total_reward / child_node.visits +
|
uctValue = childNode.stateValue +
|
||||||
c * sqrt(log(node.visits) / child_node.visits)
|
w * sqrt(log(node.visits) / childNode.visits)
|
||||||
if uct_value > max_uct
|
if uctValue > max_uct
|
||||||
max_uct = uct_value
|
max_uct = uctValue
|
||||||
selected_node = child_node
|
selectedNode = childNode
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return selected_node
|
return selectedNode
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -96,16 +99,30 @@ end
|
|||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[] implement the function
|
[WORKING] implement the function
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
|
||||||
|
n::Integer=3) where {T<:Any}
|
||||||
|
|
||||||
|
actions = []
|
||||||
|
|
||||||
|
# sampling action from decisionMaker
|
||||||
|
# for nth in 1:n
|
||||||
|
|
||||||
|
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for action in actions
|
for action in actions
|
||||||
new_state = transition(node.state, action) # Implement your transition function
|
newState = transition(node.state, action) # Implement your transition function
|
||||||
if new_state ∉ keys(node.children)
|
if newState ∉ keys(node.children)
|
||||||
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
|
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -132,7 +149,7 @@ end
|
|||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function simulate(state::T, max_depth::Int)
|
function simulate(state::T, max_depth::Int) where {T<:Any}
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
action = select_action(state) # Implement your action selection function
|
action = select_action(state) # Implement your action selection function
|
||||||
@@ -224,9 +241,6 @@ end
|
|||||||
"""
|
"""
|
||||||
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
|
||||||
# Create a complete example using the defined MCTS functions #
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Arguments\n
|
Arguments\n
|
||||||
@@ -244,23 +258,120 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
|||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
[] implement RAG to pull similar experience
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
function decisionMaker()
|
||||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
|
||||||
|
|
||||||
for _ in 1:max_iterations
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO\n
|
||||||
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function stateValueEstimator()
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO\n
|
||||||
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function reflector()
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
# Create a complete example using the defined MCTS functions #
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
""" Search for best action
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
initialState
|
||||||
|
initial state
|
||||||
|
decisionMaker::Function
|
||||||
|
decide what action to take
|
||||||
|
stateValueEstimator::Function
|
||||||
|
assess the value of the state
|
||||||
|
reflector::Function
|
||||||
|
generate lesson from trajectory and reward
|
||||||
|
n::Integer
|
||||||
|
how many times action will be sampled from decisionMaker
|
||||||
|
w::Float64
|
||||||
|
exploration weight
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO\n
|
||||||
|
-----
|
||||||
|
[] update docstring
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
|
||||||
|
reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
|
||||||
|
maxIterations::Integer, w::Float64)
|
||||||
|
root = MCTSNode(initialState, 0, 0.0, Dict())
|
||||||
|
|
||||||
|
for _ in 1:maxIterations
|
||||||
node = root
|
node = root
|
||||||
while !isLeaf(node)
|
while !isLeaf(node)
|
||||||
node = select(node, w)
|
node = select(node, w)
|
||||||
end
|
end
|
||||||
|
|
||||||
expand(node, node.state, actions)
|
expand(node, node.state, decisionMaker, stateValueEstimator,
|
||||||
|
n=n)
|
||||||
|
|
||||||
leaf_node = node.children[node.state]
|
leaf_node = node.children[node.state]
|
||||||
reward = simulate(leaf_node.state, max_depth)
|
reward = simulate(leaf_node.state, maxDepth)
|
||||||
backpropagate(leaf_node, reward)
|
backpropagate(leaf_node, reward)
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -268,13 +379,7 @@ function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w
|
|||||||
return best_child_state
|
return best_child_state
|
||||||
end
|
end
|
||||||
|
|
||||||
# Define your transition function and action selection function here
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
initial_state = 0
|
|
||||||
actions = [-1, 0, 1]
|
|
||||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
|
||||||
println("Best action to take: ", best_action)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
109
test/runtest.jl
Normal file
109
test/runtest.jl
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
using Revise # remove when this package is completed
|
||||||
|
using YiemAgent, GeneralUtils, JSON3, MQTTClient, Dates, UUIDs
|
||||||
|
using Base.Threads
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
config = copy(JSON3.read("config.json"))
|
||||||
|
|
||||||
|
instanceInternalTopic = config[:serviceInternalTopic][:value] * "/1"
|
||||||
|
|
||||||
|
client, connection = MakeConnection(config[:mqttServerInfo][:value][:broker],
|
||||||
|
config[:mqttServerInfo][:value][:port])
|
||||||
|
|
||||||
|
receiveUserMsgChannel = Channel{Dict}(4)
|
||||||
|
receiveInternalMsgChannel = Channel{Dict}(4)
|
||||||
|
|
||||||
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
|
"N/A",
|
||||||
|
replyTopic = config[:servicetopic][:value] # ask frontend reply to this instance_chat_topic
|
||||||
|
)
|
||||||
|
|
||||||
|
agentConfig = Dict(
|
||||||
|
:receiveprompt=>Dict(
|
||||||
|
:mqtttopic=> config[:servicetopic][:value], # topic to receive prompt i.e. frontend send msg to this topic
|
||||||
|
),
|
||||||
|
:receiveinternal=>Dict(
|
||||||
|
:mqtttopic=> instanceInternalTopic, # receive topic for model's internal
|
||||||
|
),
|
||||||
|
:text2text=>Dict(
|
||||||
|
:mqtttopic=> config[:text2text][:value],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Instantiate an agent
|
||||||
|
tools=Dict( # update input format
|
||||||
|
"askbox"=> Dict(
|
||||||
|
:description => "<askbox tool description>Useful for when you need to ask the user for more context. Do not ask the user their own question.</askbox tool description>",
|
||||||
|
:input => """<input>Input is a text in JSON format.</input><input example>{\"Q1\": \"How are you doing?\", \"Q2\": \"How may I help you?\"}</input example>""",
|
||||||
|
:output => "" ,
|
||||||
|
:func => nothing,
|
||||||
|
),
|
||||||
|
# "winestock"=> Dict(
|
||||||
|
# :description => "<winestock tool description>A handy tool for searching wine in your inventory that match the user preferences.</winestock tool description>",
|
||||||
|
# :input => """<input>Input is a JSON-formatted string that contains a detailed and precise search query.</input><input example>{\"wine type\": \"rose\", \"price\": \"max 35\", \"sweetness level\": \"sweet\", \"intensity level\": \"light bodied\", \"Tannin level\": \"low\", \"Acidity level\": \"low\"}</input example>""",
|
||||||
|
# :output => """<output>Output are wines that match the search query in JSON format.""",
|
||||||
|
# :func => ChatAgent.winestock,
|
||||||
|
# ),
|
||||||
|
"finalanswer"=> Dict(
|
||||||
|
:description => "<tool description>Useful for when you are ready to recommend wines to the user.</tool description>",
|
||||||
|
:input => """<input format>{\"finalanswer\": \"some text\"}.</input format><input example>{\"finalanswer\": \"I recommend Zena Crown Vista\"}</input example>""",
|
||||||
|
:output => "" ,
|
||||||
|
:func => nothing,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
a = YiemAgent.sommelier(
|
||||||
|
receiveUserMsgChannel,
|
||||||
|
receiveInternalMsgChannel,
|
||||||
|
msgMeta,
|
||||||
|
agentConfig,
|
||||||
|
name= "assistant",
|
||||||
|
id= "randomSessionID", # agent instance id
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = YiemAgent.conversation(a, Dict(:text=> "newtopic", ) )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user