This commit is contained in:
narawat lamaiin
2024-04-30 10:56:14 +07:00
parent 6c3ef4414b
commit bbdaa2248a
4 changed files with 79 additions and 68 deletions

View File

@@ -114,10 +114,10 @@ function decisionMaker(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractD
You should only respond in JSON format as describe below: You should only respond in JSON format as describe below:
{ {
"Thought_1": "reasoning", "Thought_1": "reasoning 1",
"Thought_2": "reasoning", "Thought_2": "reasoning 2",
... ...
"Thought_n": "reasoning", "Thought_n": "reasoning n",
"Action_1": {"name": "action to take", "input": "Action input"}, "Action_1": {"name": "action to take", "input": "Action input"},
"Observation_1": "result of the action" "Observation_1": "result of the action"
} }
@@ -131,10 +131,10 @@ function decisionMaker(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractD
"Action_1": {"name": "chatbox", "input": "What will you use it for?"} "Action_1": {"name": "chatbox", "input": "What will you use it for?"}
} }
{ {
"Question": "I'm looking for a sedan.", "Question": "I'm looking for a sedan with an automatic driving feature.",
"Thought_1": "I have many types of sedans in my inventory, each with diverse features.", "Thought_1": "I have many types of sedans in my inventory, each with diverse features.",
"Thought_2": "It would be easier to make a recommendation if I knew what feature the user is looking for. I should ask the user.", "Thought_2": "But there is only 1 car that has the feature customer wanted.",
"Action_1": {"name": "chatbox", "input": "Do you have any specific feature in mind?"} "Action_1": {"name": "finish", "input": "I recommend a Tesla model Y. It has your requested feature and much more."}
} }
$reflect $reflect
@@ -304,18 +304,18 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
else #[PENDING] new thinking else #[PENDING] new thinking
initialState = Dict( initialState = Dict{Symbol, Any}(
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:thoughtHistory=> Dict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
:Question=> userinput[:text], :Question=> userinput[:text],
) )
) )
bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector, bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector,
isterminal, 3, 10, 1000, 1.0) isterminal, 2, 10, 1000, 1.0)
error("---> bestplan") error("---> bestplan")
# actor loop(bestplan) # actor loop(bestplan)

View File

@@ -1,6 +1,6 @@
module llmfunction module llmfunction
# export wikisearch, winestock, askbox export virtualWineCustomerChatbox
using HTTP, JSON3, URIs, Random using HTTP, JSON3, URIs, Random
using GeneralUtils using GeneralUtils
@@ -63,10 +63,6 @@ end
julia> julia>
``` ```
# TODO
- [x] update docstring
- [TESTING] implement the function
# Signature # Signature
""" """
function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString}
@@ -89,6 +85,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent,
receiverName= "virtualWineCustomer", receiverName= "virtualWineCustomer",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= a.config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished
) )
outgoingMsg = Dict( outgoingMsg = Dict(

View File

@@ -5,7 +5,7 @@
module mcts module mcts
export MCTSNode, runMCTS export MCTSNode, runMCTS, isleaf
using Dates, UUIDs, DataStructures, JSON3, Random using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils using GeneralUtils
@@ -18,34 +18,30 @@ using ..type, ..llmfunction
# Arguments # Arguments
- `state::T` - `state::T`
a state of a game. Can be a Dict or something else. a state of a game. Can be a Dict or something else.
For example:
state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
- `visits::Integer ` - `visits::Integer `
number of time the game visits this state number of time the game visits this state
- `stateValue::Float64` - `stateValue::Float64`
state value state value
- `children::Dict{T, MCTSNode}`
children node
# Return # Return
- `nothing`
# Example # Example
```jldoctest ```jldoctest
julia> julia> state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
``` ```
# TODO
[] update docstring
[x] implement the function
# Signature # Signature
""" """
struct MCTSNode{T<:AbstractDict} struct MCTSNode{T<:AbstractDict}
@@ -131,14 +127,15 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, state
) )
:Observation_1 => "" :Observation_1 => ""
""" """
@show state
@show thoughtDict
newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
if newstate keys(node.children) if newstate keys(node.children)
node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{T, MCTSNode}()) statetype = typeof(state)
node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
end end
end end
error("--> expand")
end end
""" """
@@ -208,7 +205,7 @@ end
one of YiemAgent's agent one of YiemAgent's agent
- `state::T2` - `state::T2`
current game state current game state
- `thoughtDict::T2` - `thoughtDict::T3`
contain Thought, Action, Observation contain Thought, Action, Observation
# Return # Return
@@ -217,26 +214,32 @@ end
# Example # Example
```jldoctest ```jldoctest
julia> thoughtDict = Dict( julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:Question=> "I want to buy a bottle of wine." :thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."),
:Thought_1=> "The customer wants to buy a bottle of wine. This is a good start!", :storeinfo => Dict(),
:Action_1=> Dict{Symbol, Any}( :customerinfo => Dict()
:name=>"Chatbox", )
:input=>"What occasion are you buying the wine for?" julia> thoughtDict = Dict(
), :Question=> "I want to buy a bottle of wine.",
:Observation_1 => "" :Thought_1=> "The customer wants to buy a bottle of wine.",
) :Action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:Observation_1 => ""
)
``` ```
# TODO # TODO
- [x] update docstring - [] update docstring
- [TESTING] implement the function - [PENDING] add other actions
# Signature # Signature
""" """
function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {T1<:agent, T2<:AbstractDict} function MCTStransition(a::T1, state::T2,
latestThoughtKey, latestindice = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought") thoughtDict::T3)::AbstractDict where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
latestActionKey = GeneralUtils.findHighestIndexKey(thoughtDict, "Action") latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
_action = thoughtDict[latestActionKey] _action = thoughtDict[latestActionKey]
actionname = _action[:name] actionname = _action[:name]
actioninput = _action[:input] actioninput = _action[:input]
@@ -244,7 +247,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {
# map action and input() to llm function # map action and input() to llm function
response = response =
if actionname == "chatbox" if actionname == "chatbox"
virtualWineCustomerChatbox(a, actioninput) # user virtu virtualWineCustomerChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock" elseif actionname == "winestock"
elseif actionname == "finish" elseif actionname == "finish"
@@ -257,28 +260,38 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {
newstate = deepcopy(state) newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey] newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey] newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey]
latestObservationKey = Symbol("Observation_$(latestindice)") latestObservationKey = Symbol("Observation_$(latestActionIndice)")
newstate[:thoughtHistory][latestObservationKey] = response newstate[:thoughtHistory][latestObservationKey] = response
error("--> transition")
return newstate return newstate
end end
"""
""" Determine whether a node is a leaf node of a search tree.
# Arguments # Arguments
- `node::MCTSNode`
a search tree node
# Return # Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example # Example
```jldoctest ```jldoctest
julia> julia> using Revise
``` julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
# TODO :thoughtHistory=> OrderedDict{Symbol, Any}(
- [] update docstring :Question=> "How are you?",
- [x] implement the function )
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# Signature # Signature
""" """
@@ -366,13 +379,14 @@ function runMCTS(
end end
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n) expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
error("---> runMCTS")
leaf_node = node.children[node.state] # mark leaf node leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth) reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward) backpropagate(leaf_node, reward)
end end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
error("---> runMCTS")
return best_child_state return best_child_state
end end

View File

@@ -55,7 +55,7 @@ tools=Dict( # update input format
receiveInternalMsgChannel, receiveInternalMsgChannel,
agentConfig, agentConfig,
name= "assistant", name= "assistant",
id= "randomSessionID", # agent instance id id= "testingSessionID", # agent instance id
tools=tools, tools=tools,
) )