diff --git a/src/interface.jl b/src/interface.jl index 3e29722..c8f8b9e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,6 +1,6 @@ module interface -export addNewMessage, conversation, decisionMaker, isterminal +export addNewMessage, conversation, decisionMaker, progressValueEstimator, isterminal using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient using GeneralUtils @@ -73,6 +73,7 @@ julia> output_thoughtDict = Dict( [] implement RAG to pull similar experience [] use customerinfo [] user storeinfo + [] add reflect # Signature """ @@ -97,18 +98,6 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 # (trajectories) # """ - - """ - { - "Question": "I would like to buy a sedan.", - "Thought_1": "I have many cars in my inventory suitable for several usage scenarios.", - "Thought_2": "It would be better if I knew what the user intends to do with his car.", - "Thought_3": "I will ask the user what is the intended usecase", - "Action_1": {"name": "chatbox", "input": "What will you use it for?"} - } - """ - - _prompt = """ You are a helpful sommelier working for a wine store. @@ -180,14 +169,16 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 ) ) - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) - thoughtJsonStr = result[:response][:text] + _response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + thoughtJsonStr = _response[:response][:text] thoughtDict = copy(JSON3.read(thoughtJsonStr)) return thoughtDict end -""" +""" Assigns a scalar value to each new child node to be used for selec- +tion and backpropagation. This value effectively quantifies the agent’s progress in task completion, +serving as a heuristic to steer the search algorithm towards the most promising regions of the tree. # Arguments @@ -200,12 +191,76 @@ julia> # TODO - [] update docstring - - [] implement the function + - [x] implement the function # Signature """ -function stateValueEstimator() +function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict} + _prompt = + """ + Analyze the trajectories of a solution to a question answering task. The trajectories are + labeled by environmental observations about the situation, thoughts that can reason about + the current situation and actions that can be three types: + 1) winestock[query], which you can use to find wine in your inventory. + 2) chatbox[text], which you can use to interact with the user. + 3) finish[answer], which returns your wine reccommendation to the user. + Given a question and a trajectory, evaluate its correctness and provide your reasoning and + analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories + can be correct if the thoughts and actions so far are correct, even if the answer is not found + yet. Do not generate additional thoughts or actions. Then ending with the correctness score s + where s is an integer from 1 to 10. + + You should only respond in JSON format as describe below: + { + "Thought_1": "reasoning 1", + "Action_1": {"name": "action to take", "input": "Action input"}, + "Observation_1": "result of the action", + "Evaluation_1": {"evaluation": "your evaluation", "score": your evaluation score} + } + + Here are some examples: + { + "Question": "I'm looking for a sedan with an automatic driving feature.", + "Thought_1": "I have many types of sedans in my inventory, each with diverse features.", + "Thought_2": "But there is only 1 model that has the feature customer wanted.", + "Thought_3": "I should check our inventory first to see if we have it.", + "Action_1": {"name": "inventory", "input": "Yiem model A"}, + "Observation_1": "Yiem model A is in stock.", + "Evaluation_1": {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question. + It is also better to have simple searches corresponding to a single entity, making this the best action.", + "score": 10} + } + + $(JSON3.write(state[:thoughtHistory])) + """ + + prompt = formatLLMtext_llama3instruct("system", _prompt) + + msgMeta = GeneralUtils.generate_msgMeta( + a.config[:externalservice][:text2textinstruct][:mqtttopic], + senderName= "progressValueEstimator", + senderId= a.id, + receiverName= "text2textinstruct", + mqttBroker= a.config[:mqttServerInfo][:broker], + mqttBrokerPort= a.config[:mqttServerInfo][:port], + ) + + outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> prompt, + ) + ) + + _response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + thoughtJsonStr = _response[:response][:text] + thoughtDict = copy(JSON3.read(thoughtJsonStr)) + latestEvaluationKey, _ = + GeneralUtils.findHighestIndexKey(thoughtDict, "Evaluation") + evaluation = thoughtDict[latestEvaluationKey] + + return evaluation[:evaluation], evaluation[:score] end @@ -335,7 +390,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} :Question=> userinput[:text], ) ) - bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector, + bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, isterminal, 2, 10, 1000, 1.0) error("---> bestplan") diff --git a/src/mcts.jl b/src/mcts.jl index aa89129..326cdde 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -42,12 +42,16 @@ julia> state = Dict( ) ``` +# TODO + [] update docstring + # Signature """ struct MCTSNode{T<:AbstractDict} + statekey::String state::T visits::Integer - stateValue::AbstractFloat + progressValue::Number children::Dict{String, MCTSNode} end @@ -90,12 +94,16 @@ end """ Expand selected node # Arguments + - `a::T1` + One of YiemAgent's agent - `node::MCTSNode` MCTS node - - `state::T` + - `state::T2` a state of a game. Can be a Dict or something else. - `decisionMaker::Function` - + a function that output Thought and Action + - `progressValueEstimator::Function` + a function that output trajectory progress score # Return @@ -104,14 +112,10 @@ end julia> ``` -# TODO - - [] update docstring - - [WORKING] implement the function - # Signature """ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, - stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict} + progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict} # sampling action from decisionMaker for sample in 1:n @@ -120,15 +124,12 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, @show thoughtDict newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function - if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself - node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}()) + # add progressValueEstimator + _, progressValue = progressValueEstimator(a, newstate) + + if newStatekey ∉ keys(node.children) + node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}()) end - - # add stateValueEstimator - - - - end end @@ -145,23 +146,24 @@ julia> # TODO - [] update docstring - - [] implement the function + - [WORKING] implement the function - [] reward only comes at terminal state # Signature """ function simulate(state::T, max_depth::Int) where {T<:AbstractDict} - total_reward = 0.0 - for _ in 1:max_depth - #[] Implement your action selection function based on highest stateValue - action = select_action(state) # current state - state, reward = transition(state, action) # Implement transition function to a new state + error("--> simulate") + total_reward = 0.0 + for _ in 1:max_depth + #[] Implement your action selection function based on highest stateValue + action = select_action(state) # current state + state, reward = transition(state, action) # Implement transition function to a new state - #[] check for the terminal state + #[] check for the terminal state - total_reward += reward - end - return total_reward + total_reward += reward + end + return total_reward end """ @@ -332,7 +334,7 @@ end initial state - `decisionMaker::Function` decide what action to take - - `stateValueEstimator::Function` + - `progressValueEstimator::Function` assess the value of the state - `reflector::Function` generate lesson from trajectory and reward @@ -361,7 +363,7 @@ function runMCTS( a::T1, initialState, decisionMaker::Function, - stateValueEstimator::Function, + progressValueEstimator::Function, reflector::Function, isterminal::Function, n::Integer, @@ -369,7 +371,7 @@ function runMCTS( maxIterations::Integer, w::Float64) where {T1<:agent} - root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}()) + root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}()) for _ in 1:maxIterations node = root @@ -377,7 +379,7 @@ function runMCTS( node = select(node, w) end - expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n) + expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n) # from paper, just start simulation at this node. Not the node that newly expanded leaf_node = node diff --git a/src/type.jl b/src/type.jl index 4a57e40..c87bd48 100644 --- a/src/type.jl +++ b/src/type.jl @@ -74,9 +74,13 @@ abstract type agent end ) ``` + # TODO + - [] update docstring + - [x] implement the function + Signature\n ----- -""" #[] update docstring +""" @kwdef mutable struct sommelier <: agent name::String id::String