This commit is contained in:
narawat lamaiin
2024-05-03 10:32:41 +07:00
parent 8262423317
commit ef940b6ada
3 changed files with 111 additions and 50 deletions

View File

@@ -1,6 +1,6 @@
module interface
export addNewMessage, conversation, decisionMaker, isterminal
export addNewMessage, conversation, decisionMaker, progressValueEstimator, isterminal
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
using GeneralUtils
@@ -73,6 +73,7 @@ julia> output_thoughtDict = Dict(
[] implement RAG to pull similar experience
[] use customerinfo
[] user storeinfo
[] add reflect
# Signature
"""
@@ -97,18 +98,6 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
# (trajectories)
# """
"""
{
"Question": "I would like to buy a sedan.",
"Thought_1": "I have many cars in my inventory suitable for several usage scenarios.",
"Thought_2": "It would be better if I knew what the user intends to do with his car.",
"Thought_3": "I will ask the user what is the intended usecase",
"Action_1": {"name": "chatbox", "input": "What will you use it for?"}
}
"""
_prompt =
"""
You are a helpful sommelier working for a wine store.
@@ -180,14 +169,16 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
)
)
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
thoughtJsonStr = result[:response][:text]
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
thoughtJsonStr = _response[:response][:text]
thoughtDict = copy(JSON3.read(thoughtJsonStr))
return thoughtDict
end
"""
""" Assigns a scalar value to each new child node to be used for selec-
tion and backpropagation. This value effectively quantifies the agents progress in task completion,
serving as a heuristic to steer the search algorithm towards the most promising regions of the tree.
# Arguments
@@ -200,12 +191,76 @@ julia>
# TODO
- [] update docstring
- [] implement the function
- [x] implement the function
# Signature
"""
function stateValueEstimator()
# Ask the text2text LLM (over MQTT) to evaluate the current trajectory in
# `state` and return (evaluation text, integer progress score in 1..10).
# NOTE(review): return annotation uses abstract `Integer`; the parsed JSON
# score is presumably Int64 — confirm against the LLM service output.
function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
# Evaluation prompt: instructs the model to judge the trajectory recorded in
# state[:thoughtHistory] and to reply in JSON containing "Evaluation_N" entries.
_prompt =
"""
Analyze the trajectories of a solution to a question answering task. The trajectories are
labeled by environmental observations about the situation, thoughts that can reason about
the current situation and actions that can be three types:
1) winestock[query], which you can use to find wine in your inventory.
2) chatbox[text], which you can use to interact with the user.
3) finish[answer], which returns your wine reccommendation to the user.
Given a question and a trajectory, evaluate its correctness and provide your reasoning and
analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
can be correct if the thoughts and actions so far are correct, even if the answer is not found
yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
where s is an integer from 1 to 10.
You should only respond in JSON format as describe below:
{
"Thought_1": "reasoning 1",
"Action_1": {"name": "action to take", "input": "Action input"},
"Observation_1": "result of the action",
"Evaluation_1": {"evaluation": "your evaluation", "score": your evaluation score}
}
Here are some examples:
{
"Question": "I'm looking for a sedan with an automatic driving feature.",
"Thought_1": "I have many types of sedans in my inventory, each with diverse features.",
"Thought_2": "But there is only 1 model that has the feature customer wanted.",
"Thought_3": "I should check our inventory first to see if we have it.",
"Action_1": {"name": "inventory", "input": "Yiem model A"},
"Observation_1": "Yiem model A is in stock.",
"Evaluation_1": {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
It is also better to have simple searches corresponding to a single entity, making this the best action.",
"score": 10}
}
$(JSON3.write(state[:thoughtHistory]))
"""
# Wrap the raw prompt in the llama3-instruct chat template under the "system" role.
prompt = formatLLMtext_llama3instruct("system", _prompt)
# Build MQTT message metadata addressed to the agent's text2textinstruct service,
# using broker host/port from the agent's config.
msgMeta = GeneralUtils.generate_msgMeta(
a.config[:externalservice][:text2textinstruct][:mqtttopic],
senderName= "progressValueEstimator",
senderId= a.id,
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
)
)
# Blocking request/response over MQTT; the model's JSON reply is in
# _response[:response][:text].
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
thoughtJsonStr = _response[:response][:text]
# copy() materializes the read-only JSON3 object into a mutable Dict.
thoughtDict = copy(JSON3.read(thoughtJsonStr))
# Select the "Evaluation_k" entry with the highest index k, i.e. the latest step.
latestEvaluationKey, _ =
GeneralUtils.findHighestIndexKey(thoughtDict, "Evaluation")
evaluation = thoughtDict[latestEvaluationKey]
# NOTE(review): assumes the LLM reply always contains an "Evaluation_*" key with
# :evaluation and :score fields — no fallback if the model deviates from the format.
return evaluation[:evaluation], evaluation[:score]
end
@@ -335,7 +390,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
:Question=> userinput[:text],
)
)
bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector,
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
isterminal, 2, 10, 1000, 1.0)
error("---> bestplan")

View File

@@ -42,12 +42,16 @@ julia> state = Dict(
)
```
# TODO
[] update docstring
# Signature
"""
struct MCTSNode{T<:AbstractDict}
statekey::String
state::T
visits::Integer
stateValue::AbstractFloat
progressValue::Number
children::Dict{String, MCTSNode}
end
@@ -90,12 +94,16 @@ end
""" Expand selected node
# Arguments
- `a::T1`
One of YiemAgent's agent
- `node::MCTSNode`
MCTS node
- `state::T`
- `state::T2`
a state of a game. Can be a Dict or something else.
- `decisionMaker::Function`
a function that output Thought and Action
- `progressValueEstimator::Function`
a function that output trajectory progress score
# Return
@@ -104,14 +112,10 @@ end
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker
for sample in 1:n
@@ -120,15 +124,12 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
@show thoughtDict
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
# add progressValueEstimator
_, progressValue = progressValueEstimator(a, newstate)
if newStatekey ∉ keys(node.children)
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
end
# add stateValueEstimator
end
end
@@ -145,23 +146,24 @@ julia>
# TODO
- [] update docstring
- [] implement the function
- [WORKING] implement the function
- [] reward only comes at terminal state
# Signature
"""
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
total_reward = 0.0
for _ in 1:max_depth
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
error("--> simulate")
total_reward = 0.0
for _ in 1:max_depth
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
#[] check for the terminal state
#[] check for the terminal state
total_reward += reward
end
return total_reward
total_reward += reward
end
return total_reward
end
"""
@@ -332,7 +334,7 @@ end
initial state
- `decisionMaker::Function`
decide what action to take
- `stateValueEstimator::Function`
- `progressValueEstimator::Function`
assess the value of the state
- `reflector::Function`
generate lesson from trajectory and reward
@@ -361,7 +363,7 @@ function runMCTS(
a::T1,
initialState,
decisionMaker::Function,
stateValueEstimator::Function,
progressValueEstimator::Function,
reflector::Function,
isterminal::Function,
n::Integer,
@@ -369,7 +371,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root
@@ -377,7 +379,7 @@ function runMCTS(
node = select(node, w)
end
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
# from paper, just start simulation at this node. Not the node that newly expanded
leaf_node = node

View File

@@ -74,9 +74,13 @@ abstract type agent end
)
```
# TODO
- [] update docstring
- [x] implement the function
Signature\n
-----
""" #[] update docstring
"""
@kwdef mutable struct sommelier <: agent
name::String
id::String