update
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
module interface
|
module interface
|
||||||
|
|
||||||
export addNewMessage, conversation, decisionMaker, isterminal
|
export addNewMessage, conversation, decisionMaker, progressValueEstimator, isterminal
|
||||||
|
|
||||||
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
|
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -73,6 +73,7 @@ julia> output_thoughtDict = Dict(
|
|||||||
[] implement RAG to pull similar experience
|
[] implement RAG to pull similar experience
|
||||||
[] use customerinfo
|
[] use customerinfo
|
||||||
[] use storeinfo
|
[] use storeinfo
|
||||||
|
[] add reflect
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -97,18 +98,6 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
# (trajectories)
|
# (trajectories)
|
||||||
# """
|
# """
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
{
|
|
||||||
"Question": "I would like to buy a sedan.",
|
|
||||||
"Thought_1": "I have many cars in my inventory suitable for several usage scenarios.",
|
|
||||||
"Thought_2": "It would be better if I knew what the user intends to do with his car.",
|
|
||||||
"Thought_3": "I will ask the user what is the intended usecase",
|
|
||||||
"Action_1": {"name": "chatbox", "input": "What will you use it for?"}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
_prompt =
|
_prompt =
|
||||||
"""
|
"""
|
||||||
You are a helpful sommelier working for a wine store.
|
You are a helpful sommelier working for a wine store.
|
||||||
@@ -180,14 +169,16 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
thoughtJsonStr = result[:response][:text]
|
thoughtJsonStr = _response[:response][:text]
|
||||||
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
||||||
return thoughtDict
|
return thoughtDict
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
"""
|
""" Assigns a scalar value to each new child node to be used for selection
|
||||||
|
and backpropagation. This value effectively quantifies the agent’s progress in task completion,
|
||||||
|
serving as a heuristic to steer the search algorithm towards the most promising regions of the tree.
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
|
||||||
@@ -200,12 +191,76 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [] implement the function
|
- [x] implement the function
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function stateValueEstimator()
|
function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
_prompt =
|
||||||
|
"""
|
||||||
|
Analyze the trajectories of a solution to a question answering task. The trajectories are
|
||||||
|
labeled by environmental observations about the situation, thoughts that can reason about
|
||||||
|
the current situation and actions that can be three types:
|
||||||
|
1) winestock[query], which you can use to find wine in your inventory.
|
||||||
|
2) chatbox[text], which you can use to interact with the user.
|
||||||
|
3) finish[answer], which returns your wine reccommendation to the user.
|
||||||
|
|
||||||
|
Given a question and a trajectory, evaluate its correctness and provide your reasoning and
|
||||||
|
analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
|
||||||
|
can be correct if the thoughts and actions so far are correct, even if the answer is not found
|
||||||
|
yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
|
||||||
|
where s is an integer from 1 to 10.
|
||||||
|
|
||||||
|
You should only respond in JSON format as describe below:
|
||||||
|
{
|
||||||
|
"Thought_1": "reasoning 1",
|
||||||
|
"Action_1": {"name": "action to take", "input": "Action input"},
|
||||||
|
"Observation_1": "result of the action",
|
||||||
|
"Evaluation_1": {"evaluation": "your evaluation", "score": your evaluation score}
|
||||||
|
}
|
||||||
|
|
||||||
|
Here are some examples:
|
||||||
|
{
|
||||||
|
"Question": "I'm looking for a sedan with an automatic driving feature.",
|
||||||
|
"Thought_1": "I have many types of sedans in my inventory, each with diverse features.",
|
||||||
|
"Thought_2": "But there is only 1 model that has the feature customer wanted.",
|
||||||
|
"Thought_3": "I should check our inventory first to see if we have it.",
|
||||||
|
"Action_1": {"name": "inventory", "input": "Yiem model A"},
|
||||||
|
"Observation_1": "Yiem model A is in stock.",
|
||||||
|
"Evaluation_1": {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||||
|
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||||
|
"score": 10}
|
||||||
|
}
|
||||||
|
|
||||||
|
$(JSON3.write(state[:thoughtHistory]))
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt = formatLLMtext_llama3instruct("system", _prompt)
|
||||||
|
|
||||||
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
|
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
||||||
|
senderName= "progressValueEstimator",
|
||||||
|
senderId= a.id,
|
||||||
|
receiverName= "text2textinstruct",
|
||||||
|
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||||
|
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||||
|
)
|
||||||
|
|
||||||
|
outgoingMsg = Dict(
|
||||||
|
:msgMeta=> msgMeta,
|
||||||
|
:payload=> Dict(
|
||||||
|
:text=> prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
|
thoughtJsonStr = _response[:response][:text]
|
||||||
|
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
||||||
|
latestEvaluationKey, _ =
|
||||||
|
GeneralUtils.findHighestIndexKey(thoughtDict, "Evaluation")
|
||||||
|
evaluation = thoughtDict[latestEvaluationKey]
|
||||||
|
|
||||||
|
return evaluation[:evaluation], evaluation[:score]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -335,7 +390,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
:Question=> userinput[:text],
|
:Question=> userinput[:text],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector,
|
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
|
||||||
isterminal, 2, 10, 1000, 1.0)
|
isterminal, 2, 10, 1000, 1.0)
|
||||||
error("---> bestplan")
|
error("---> bestplan")
|
||||||
|
|
||||||
|
|||||||
62
src/mcts.jl
62
src/mcts.jl
@@ -42,12 +42,16 @@ julia> state = Dict(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
[] update docstring
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
struct MCTSNode{T<:AbstractDict}
|
struct MCTSNode{T<:AbstractDict}
|
||||||
|
statekey::String
|
||||||
state::T
|
state::T
|
||||||
visits::Integer
|
visits::Integer
|
||||||
stateValue::AbstractFloat
|
progressValue::Number
|
||||||
children::Dict{String, MCTSNode}
|
children::Dict{String, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -90,12 +94,16 @@ end
|
|||||||
""" Expand selected node
|
""" Expand selected node
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
- `a::T1`
|
||||||
|
One of YiemAgent's agents
|
||||||
- `node::MCTSNode`
|
- `node::MCTSNode`
|
||||||
MCTS node
|
MCTS node
|
||||||
- `state::T`
|
- `state::T2`
|
||||||
a state of a game. Can be a Dict or something else.
|
a state of a game. Can be a Dict or something else.
|
||||||
- `decisionMaker::Function`
|
- `decisionMaker::Function`
|
||||||
|
a function that outputs a Thought and an Action
|
||||||
|
- `progressValueEstimator::Function`
|
||||||
|
a function that outputs a trajectory progress score
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
|
|
||||||
@@ -104,14 +112,10 @@ end
|
|||||||
julia>
|
julia>
|
||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
|
||||||
- [] update docstring
|
|
||||||
- [WORKING] implement the function
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||||
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
# sampling action from decisionMaker
|
# sampling action from decisionMaker
|
||||||
for sample in 1:n
|
for sample in 1:n
|
||||||
@@ -120,15 +124,12 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
|||||||
@show thoughtDict
|
@show thoughtDict
|
||||||
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||||
|
|
||||||
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
|
# add progressValueEstimator
|
||||||
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
|
_, progressValue = progressValueEstimator(a, newstate)
|
||||||
|
|
||||||
|
if newStatekey ∉ keys(node.children)
|
||||||
|
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
|
|
||||||
# add stateValueEstimator
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -145,23 +146,24 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [] implement the function
|
- [WORKING] implement the function
|
||||||
- [] reward only comes at terminal state
|
- [] reward only comes at terminal state
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
||||||
total_reward = 0.0
|
error("--> simulate")
|
||||||
for _ in 1:max_depth
|
total_reward = 0.0
|
||||||
#[] Implement your action selection function based on highest stateValue
|
for _ in 1:max_depth
|
||||||
action = select_action(state) # current state
|
#[] Implement your action selection function based on highest stateValue
|
||||||
state, reward = transition(state, action) # Implement transition function to a new state
|
action = select_action(state) # current state
|
||||||
|
state, reward = transition(state, action) # Implement transition function to a new state
|
||||||
|
|
||||||
#[] check for the terminal state
|
#[] check for the terminal state
|
||||||
|
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
end
|
end
|
||||||
return total_reward
|
return total_reward
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -332,7 +334,7 @@ end
|
|||||||
initial state
|
initial state
|
||||||
- `decisionMaker::Function`
|
- `decisionMaker::Function`
|
||||||
decide what action to take
|
decide what action to take
|
||||||
- `stateValueEstimator::Function`
|
- `progressValueEstimator::Function`
|
||||||
assess the value of the state
|
assess the value of the state
|
||||||
- `reflector::Function`
|
- `reflector::Function`
|
||||||
generate lesson from trajectory and reward
|
generate lesson from trajectory and reward
|
||||||
@@ -361,7 +363,7 @@ function runMCTS(
|
|||||||
a::T1,
|
a::T1,
|
||||||
initialState,
|
initialState,
|
||||||
decisionMaker::Function,
|
decisionMaker::Function,
|
||||||
stateValueEstimator::Function,
|
progressValueEstimator::Function,
|
||||||
reflector::Function,
|
reflector::Function,
|
||||||
isterminal::Function,
|
isterminal::Function,
|
||||||
n::Integer,
|
n::Integer,
|
||||||
@@ -369,7 +371,7 @@ function runMCTS(
|
|||||||
maxIterations::Integer,
|
maxIterations::Integer,
|
||||||
w::Float64) where {T1<:agent}
|
w::Float64) where {T1<:agent}
|
||||||
|
|
||||||
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
for _ in 1:maxIterations
|
for _ in 1:maxIterations
|
||||||
node = root
|
node = root
|
||||||
@@ -377,7 +379,7 @@ function runMCTS(
|
|||||||
node = select(node, w)
|
node = select(node, w)
|
||||||
end
|
end
|
||||||
|
|
||||||
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
|
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
|
||||||
|
|
||||||
# from the paper: start the simulation at this node, not at the newly expanded node
|
# from the paper: start the simulation at this node, not at the newly expanded node
|
||||||
leaf_node = node
|
leaf_node = node
|
||||||
|
|||||||
@@ -74,9 +74,13 @@ abstract type agent end
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docstring
|
||||||
|
- [x] implement the function
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
""" #[] update docstring
|
"""
|
||||||
@kwdef mutable struct sommelier <: agent
|
@kwdef mutable struct sommelier <: agent
|
||||||
name::String
|
name::String
|
||||||
id::String
|
id::String
|
||||||
|
|||||||
Reference in New Issue
Block a user