update
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
module interface
|
||||
|
||||
export addNewMessage, conversation, decisionMaker, isterminal
|
||||
export addNewMessage, conversation, decisionMaker, progressValueEstimator, isterminal
|
||||
|
||||
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
|
||||
using GeneralUtils
|
||||
@@ -73,6 +73,7 @@ julia> output_thoughtDict = Dict(
|
||||
[] implement RAG to pull similar experience
|
||||
[] use customerinfo
|
||||
[] user storeinfo
|
||||
[] add reflect
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -97,18 +98,6 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
# (trajectories)
|
||||
# """
|
||||
|
||||
|
||||
"""
|
||||
{
|
||||
"Question": "I would like to buy a sedan.",
|
||||
"Thought_1": "I have many cars in my inventory suitable for several usage scenarios.",
|
||||
"Thought_2": "It would be better if I knew what the user intends to do with his car.",
|
||||
"Thought_3": "I will ask the user what is the intended usecase",
|
||||
"Action_1": {"name": "chatbox", "input": "What will you use it for?"}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
_prompt =
|
||||
"""
|
||||
You are a helpful sommelier working for a wine store.
|
||||
@@ -180,14 +169,16 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
)
|
||||
)
|
||||
|
||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
thoughtJsonStr = result[:response][:text]
|
||||
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
thoughtJsonStr = _response[:response][:text]
|
||||
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
||||
return thoughtDict
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
""" Assigns a scalar value to each new child node to be used for selec-
|
||||
tion and backpropagation. This value effectively quantifies the agent’s progress in task completion,
|
||||
serving as a heuristic to steer the search algorithm towards the most promising regions of the tree.
|
||||
|
||||
# Arguments
|
||||
|
||||
@@ -200,12 +191,76 @@ julia>
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [] implement the function
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function stateValueEstimator()
|
||||
function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
||||
_prompt =
|
||||
"""
|
||||
Analyze the trajectories of a solution to a question answering task. The trajectories are
|
||||
labeled by environmental observations about the situation, thoughts that can reason about
|
||||
the current situation and actions that can be three types:
|
||||
1) winestock[query], which you can use to find wine in your inventory.
|
||||
2) chatbox[text], which you can use to interact with the user.
|
||||
3) finish[answer], which returns your wine reccommendation to the user.
|
||||
|
||||
Given a question and a trajectory, evaluate its correctness and provide your reasoning and
|
||||
analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
|
||||
can be correct if the thoughts and actions so far are correct, even if the answer is not found
|
||||
yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
|
||||
where s is an integer from 1 to 10.
|
||||
|
||||
You should only respond in JSON format as describe below:
|
||||
{
|
||||
"Thought_1": "reasoning 1",
|
||||
"Action_1": {"name": "action to take", "input": "Action input"},
|
||||
"Observation_1": "result of the action",
|
||||
"Evaluation_1": {"evaluation": "your evaluation", "score": your evaluation score}
|
||||
}
|
||||
|
||||
Here are some examples:
|
||||
{
|
||||
"Question": "I'm looking for a sedan with an automatic driving feature.",
|
||||
"Thought_1": "I have many types of sedans in my inventory, each with diverse features.",
|
||||
"Thought_2": "But there is only 1 model that has the feature customer wanted.",
|
||||
"Thought_3": "I should check our inventory first to see if we have it.",
|
||||
"Action_1": {"name": "inventory", "input": "Yiem model A"},
|
||||
"Observation_1": "Yiem model A is in stock.",
|
||||
"Evaluation_1": {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||
"score": 10}
|
||||
}
|
||||
|
||||
$(JSON3.write(state[:thoughtHistory]))
|
||||
"""
|
||||
|
||||
prompt = formatLLMtext_llama3instruct("system", _prompt)
|
||||
|
||||
msgMeta = GeneralUtils.generate_msgMeta(
|
||||
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
||||
senderName= "progressValueEstimator",
|
||||
senderId= a.id,
|
||||
receiverName= "text2textinstruct",
|
||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||
)
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> prompt,
|
||||
)
|
||||
)
|
||||
|
||||
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
thoughtJsonStr = _response[:response][:text]
|
||||
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
||||
latestEvaluationKey, _ =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "Evaluation")
|
||||
evaluation = thoughtDict[latestEvaluationKey]
|
||||
|
||||
return evaluation[:evaluation], evaluation[:score]
|
||||
end
|
||||
|
||||
|
||||
@@ -335,7 +390,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
:Question=> userinput[:text],
|
||||
)
|
||||
)
|
||||
bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector,
|
||||
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
|
||||
isterminal, 2, 10, 1000, 1.0)
|
||||
error("---> bestplan")
|
||||
|
||||
|
||||
62
src/mcts.jl
62
src/mcts.jl
@@ -42,12 +42,16 @@ julia> state = Dict(
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
|
||||
# Signature
|
||||
"""
|
||||
struct MCTSNode{T<:AbstractDict}
|
||||
statekey::String
|
||||
state::T
|
||||
visits::Integer
|
||||
stateValue::AbstractFloat
|
||||
progressValue::Number
|
||||
children::Dict{String, MCTSNode}
|
||||
end
|
||||
|
||||
@@ -90,12 +94,16 @@ end
|
||||
""" Expand selected node
|
||||
|
||||
# Arguments
|
||||
- `a::T1`
|
||||
One of YiemAgent's agent
|
||||
- `node::MCTSNode`
|
||||
MCTS node
|
||||
- `state::T`
|
||||
- `state::T2`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
- `decisionMaker::Function`
|
||||
|
||||
a function that output Thought and Action
|
||||
- `progressValueEstimator::Function`
|
||||
a function that output trajectory progress score
|
||||
|
||||
# Return
|
||||
|
||||
@@ -104,14 +112,10 @@ end
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [WORKING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
@@ -120,15 +124,12 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||
@show thoughtDict
|
||||
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
|
||||
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
|
||||
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
|
||||
# add progressValueEstimator
|
||||
_, progressValue = progressValueEstimator(a, newstate)
|
||||
|
||||
if newStatekey ∉ keys(node.children)
|
||||
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
|
||||
end
|
||||
|
||||
# add stateValueEstimator
|
||||
|
||||
|
||||
|
||||
|
||||
end
|
||||
end
|
||||
|
||||
@@ -145,23 +146,24 @@ julia>
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [] implement the function
|
||||
- [WORKING] implement the function
|
||||
- [] reward only comes at terminal state
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
||||
total_reward = 0.0
|
||||
for _ in 1:max_depth
|
||||
#[] Implement your action selection function based on highest stateValue
|
||||
action = select_action(state) # current state
|
||||
state, reward = transition(state, action) # Implement transition function to a new state
|
||||
error("--> simulate")
|
||||
total_reward = 0.0
|
||||
for _ in 1:max_depth
|
||||
#[] Implement your action selection function based on highest stateValue
|
||||
action = select_action(state) # current state
|
||||
state, reward = transition(state, action) # Implement transition function to a new state
|
||||
|
||||
#[] check for the terminal state
|
||||
#[] check for the terminal state
|
||||
|
||||
total_reward += reward
|
||||
end
|
||||
return total_reward
|
||||
total_reward += reward
|
||||
end
|
||||
return total_reward
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -332,7 +334,7 @@ end
|
||||
initial state
|
||||
- `decisionMaker::Function`
|
||||
decide what action to take
|
||||
- `stateValueEstimator::Function`
|
||||
- `progressValueEstimator::Function`
|
||||
assess the value of the state
|
||||
- `reflector::Function`
|
||||
generate lesson from trajectory and reward
|
||||
@@ -361,7 +363,7 @@ function runMCTS(
|
||||
a::T1,
|
||||
initialState,
|
||||
decisionMaker::Function,
|
||||
stateValueEstimator::Function,
|
||||
progressValueEstimator::Function,
|
||||
reflector::Function,
|
||||
isterminal::Function,
|
||||
n::Integer,
|
||||
@@ -369,7 +371,7 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
|
||||
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
|
||||
|
||||
for _ in 1:maxIterations
|
||||
node = root
|
||||
@@ -377,7 +379,7 @@ function runMCTS(
|
||||
node = select(node, w)
|
||||
end
|
||||
|
||||
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
|
||||
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
|
||||
|
||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||
leaf_node = node
|
||||
|
||||
@@ -74,9 +74,13 @@ abstract type agent end
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
""" #[] update docstring
|
||||
"""
|
||||
@kwdef mutable struct sommelier <: agent
|
||||
name::String
|
||||
id::String
|
||||
|
||||
Reference in New Issue
Block a user