diff --git a/src/interface.jl b/src/interface.jl index 55be066..7982fd5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -114,10 +114,10 @@ function decisionMaker(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractD You should only respond in JSON format as describe below: { - "Thought_1": "reasoning", - "Thought_2": "reasoning", + "Thought_1": "reasoning 1", + "Thought_2": "reasoning 2", ... - "Thought_n": "reasoning", + "Thought_n": "reasoning n", "Action_1": {"name": "action to take", "input": "Action input"}, "Observation_1": "result of the action" } @@ -131,10 +131,10 @@ function decisionMaker(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractD "Action_1": {"name": "chatbox", "input": "What will you use it for?"} } { - "Question": "I'm looking for a sedan.", + "Question": "I'm looking for a sedan with an automatic driving feature.", "Thought_1": "I have many types of sedans in my inventory, each with diverse features.", - "Thought_2": "It would be easier to make a recommendation if I knew what feature the user is looking for. I should ask the user.", - "Action_1": {"name": "chatbox", "input": "Do you have any specific feature in mind?"} + "Thought_2": "But there is only 1 car that has the feature customer wanted.", + "Action_1": {"name": "finish", "input": "I recommend a Tesla model Y. It has your requested feature and much more."} } $reflect @@ -304,18 +304,18 @@ function conversation(a::T, userinput::Dict) where {T<:agent} else #[PENDING] new thinking - initialState = Dict( + initialState = Dict{Symbol, Any}( # deepcopy the info to prevent modifying the info unintentionally during MCTS planning :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), - :thoughtHistory=> Dict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... + :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :Question=> userinput[:text], - ) - ) + ) + ) bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector, - isterminal, 3, 10, 1000, 1.0) + isterminal, 2, 10, 1000, 1.0) error("---> bestplan") # actor loop(bestplan) diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 2ba2954..a4c705e 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -1,6 +1,6 @@ module llmfunction -# export wikisearch, winestock, askbox +export virtualWineCustomerChatbox using HTTP, JSON3, URIs, Random using GeneralUtils @@ -63,10 +63,6 @@ end julia> ``` -# TODO - - [x] update docstring - - [TESTING] implement the function - # Signature """ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} @@ -89,6 +85,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, receiverName= "virtualWineCustomer", mqttBroker= a.config[:mqttServerInfo][:broker], mqttBrokerPort= a.config[:mqttServerInfo][:port], + msgId = "dummyid" #CHANGE remove after testing finished ) outgoingMsg = Dict( diff --git a/src/mcts.jl b/src/mcts.jl index ae8ff2c..c8d3511 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -5,7 +5,7 @@ module mcts -export MCTSNode, runMCTS +export MCTSNode, runMCTS, isleaf using Dates, UUIDs, DataStructures, JSON3, Random using GeneralUtils @@ -18,34 +18,30 @@ using ..type, ..llmfunction # Arguments - `state::T` a state of a game. Can be a Dict or something else. - For example: - state = Dict( - :info=> Dict(), # keyword info - :thoughtHistory=> Dict( - :question=> _, - :thought_1=> _, - :action_1=> _, - :observation_1=> _, - :thought_2=> _, - ... - ) - ) - `visits::Integer ` number of time the game visits this state - `stateValue::Float64` state value + - `children::Dict{T, MCTSNode}` + children node # Return - + - `nothing` # Example ```jldoctest -julia> +julia> state = Dict( + :info=> Dict(), # keyword info + :thoughtHistory=> Dict( + :question=> _, + :thought_1=> _, + :action_1=> _, + :observation_1=> _, + :thought_2=> _, + ... + ) + ) ``` -# TODO - [] update docstring - [x] implement the function - # Signature """ struct MCTSNode{T<:AbstractDict} @@ -131,14 +127,15 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, state ) :Observation_1 => "" """ - + @show state + @show thoughtDict newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function if newstate ∉ keys(node.children) - node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{T, MCTSNode}()) + statetype = typeof(state) + node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}()) end - end - error("--> expand") + end end """ @@ -208,7 +205,7 @@ end one of YiemAgent's agent - `state::T2` current game state - - `thoughtDict::T2` + - `thoughtDict::T3` contain Thought, Action, Observation # Return @@ -217,26 +214,32 @@ end # Example ```jldoctest -julia> thoughtDict = Dict( - :Question=> "I want to buy a bottle of wine." - :Thought_1=> "The customer wants to buy a bottle of wine. This is a good start!", - :Action_1=> Dict{Symbol, Any}( - :name=>"Chatbox", - :input=>"What occasion are you buying the wine for?" - ), - :Observation_1 => "" -) +julia> state = Dict{Symbol, Dict{Symbol, Any}}( + :thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."), + :storeinfo => Dict(), + :customerinfo => Dict() + ) +julia> thoughtDict = Dict( + :Question=> "I want to buy a bottle of wine.", + :Thought_1=> "The customer wants to buy a bottle of wine.", + :Action_1=> Dict{Symbol, Any}( + :name=>"Chatbox", + :input=>"What occasion are you buying the wine for?", + ), + :Observation_1 => "" + ) ``` # TODO - - [x] update docstring - - [TESTING] implement the function + - [] update docstring + - [PENDING] add other actions # Signature """ -function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {T1<:agent, T2<:AbstractDict} - latestThoughtKey, latestindice = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought") - latestActionKey = GeneralUtils.findHighestIndexKey(thoughtDict, "Action") +function MCTStransition(a::T1, state::T2, + thoughtDict::T3)::AbstractDict where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} + latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought") + latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action") _action = thoughtDict[latestActionKey] actionname = _action[:name] actioninput = _action[:input] @@ -244,7 +247,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where { # map action and input() to llm function response = if actionname == "chatbox" - virtualWineCustomerChatbox(a, actioninput) # user virtu + virtualWineCustomerChatbox(a, actioninput) # virtual customer elseif actionname == "winestock" elseif actionname == "finish" @@ -257,28 +260,38 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where { newstate = deepcopy(state) newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey] newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey] - latestObservationKey = Symbol("Observation_$(latestindice)") + latestObservationKey = Symbol("Observation_$(latestActionIndice)") newstate[:thoughtHistory][latestObservationKey] = response - - error("--> transition") return newstate end -""" + +""" Determine whether a node is a leaf node of a search tree. # Arguments - + - `node::MCTSNode` + a search tree node # Return - + - `result::Bool` + true if it is a leaf node, false otherwise. # Example ```jldoctest -julia> -``` +julia> using Revise +julia> using YiemAgent, DataStructures +julia> initialState = Dict{Symbol, Any}( + :customerinfo=> Dict{Symbol, Any}(), + :storeinfo=> Dict{Symbol, Any}(), -# TODO - - [] update docstring - - [x] implement the function + :thoughtHistory=> OrderedDict{Symbol, Any}( + :Question=> "How are you?", + ) + ) +julia> statetype = typeof(initialState) +julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}()) +julia> YiemAgent.isleaf(root) +true +``` # Signature """ @@ -366,13 +379,14 @@ function runMCTS( end expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n) - error("---> runMCTS") + leaf_node = node.children[node.state] # mark leaf node reward = simulate(leaf_node.state, maxDepth) backpropagate(leaf_node, reward) end best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) + error("---> runMCTS") return best_child_state end diff --git a/test/runtest.jl b/test/runtest.jl index 7f05ab4..266e882 100644 --- a/test/runtest.jl +++ b/test/runtest.jl @@ -55,7 +55,7 @@ tools=Dict( # update input format receiveInternalMsgChannel, agentConfig, name= "assistant", - id= "randomSessionID", # agent instance id + id= "testingSessionID", # agent instance id tools=tools, )