update
This commit is contained in:
178
src/interface.jl
178
src/interface.jl
@@ -295,7 +295,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
|||||||
}
|
}
|
||||||
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||||
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||||
"score": 7
|
"score": 10
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -309,7 +309,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
|||||||
}
|
}
|
||||||
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
||||||
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
||||||
"score": 3
|
"score": 0
|
||||||
}
|
}
|
||||||
|
|
||||||
Let's begin!:
|
Let's begin!:
|
||||||
@@ -378,6 +378,143 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
# """
|
||||||
|
|
||||||
|
# # Arguments
|
||||||
|
# - `a::T1`
|
||||||
|
# one of Yiem's agent
|
||||||
|
# - `state::T2`
|
||||||
|
# a game state
|
||||||
|
|
||||||
|
# # Return
|
||||||
|
# - `evaluation::Tuple{String, Integer}`
|
||||||
|
# evaluation and score
|
||||||
|
|
||||||
|
# # Example
|
||||||
|
# ```jldoctest
|
||||||
|
# julia>
|
||||||
|
# ```
|
||||||
|
|
||||||
|
# # TODO
|
||||||
|
# - [] update docs
|
||||||
|
# - [] implement the function
|
||||||
|
|
||||||
|
# # Signature
|
||||||
|
# """
|
||||||
|
# function comparer(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
|
# _prompt =
|
||||||
|
# """
|
||||||
|
# Analyze the trajectories of a solution to a question answering task. The trajectories are
|
||||||
|
# labeled by environmental observations about the situation, thoughts that can reason about
|
||||||
|
# the current situation and actions that can be three types:
|
||||||
|
# 1) winestock[query], which you can use to find wine in your inventory.
|
||||||
|
# 2) chatbox[text], which you can use to interact with the user.
|
||||||
|
# 3) recommendbox[answer], which returns your wine recommendation to the user.
|
||||||
|
|
||||||
|
# Given a question and a trajectory, evaluate its correctness and provide your reasoning and
|
||||||
|
# analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
|
||||||
|
# can be correct if the thoughts and actions so far are correct, even if the answer is not found
|
||||||
|
# yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
|
||||||
|
# where s is an integer from 0 to 10.
|
||||||
|
|
||||||
|
# You should only respond in JSON format as describe below:
|
||||||
|
# {"evaluation": "your evaluation", "score": "your evaluation score"}
|
||||||
|
|
||||||
|
# Here are some examples:
|
||||||
|
# {
|
||||||
|
# "question": "I'm looking for a sedan with an automatic driving feature.",
|
||||||
|
# "thought_1": "I have many types of sedans in my inventory, each with diverse features.",
|
||||||
|
# "thought_2": "But there is only 1 model that has the feature customer wanted.",
|
||||||
|
# "thought_3": "I should check our inventory first to see if we have it.",
|
||||||
|
# "action_1": {"name": "inventory", "input": "Yiem model A"},
|
||||||
|
# "observation_1": "Yiem model A is in stock."
|
||||||
|
# }
|
||||||
|
# {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||||
|
# It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||||
|
# "score": 10
|
||||||
|
# }
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?",
|
||||||
|
# "thought_1": "Let me check our inventory first to see if I have it.",
|
||||||
|
# "action_1": {"name": "inventory", "input": "pen with 4 color and a pencil."},
|
||||||
|
# "observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}",
|
||||||
|
# "thought_2": "Ok, I have what the user is asking. Let's tell the user.",
|
||||||
|
# "action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"},
|
||||||
|
# "observation_1": "This is not what I wanted."
|
||||||
|
# }
|
||||||
|
# {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
||||||
|
# not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
||||||
|
# "score": 0
|
||||||
|
# }
|
||||||
|
|
||||||
|
# Let's begin!:
|
||||||
|
# $(JSON3.write(state[:thoughtHistory]))
|
||||||
|
# {"evaluation"
|
||||||
|
# """
|
||||||
|
|
||||||
|
# # apply LLM specific instruct format
|
||||||
|
# externalService = a.config[:externalservice][:text2textinstruct]
|
||||||
|
# llminfo = externalService[:llminfo]
|
||||||
|
# prompt =
|
||||||
|
# if llminfo[:name] == "llama3instruct"
|
||||||
|
# formatLLMtext_llama3instruct("system", _prompt)
|
||||||
|
# else
|
||||||
|
# error("llm model name is not defied yet $(@__LINE__)")
|
||||||
|
# end
|
||||||
|
|
||||||
|
# msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
|
# a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
||||||
|
# senderName= "evaluator",
|
||||||
|
# senderId= a.id,
|
||||||
|
# receiverName= "text2textinstruct",
|
||||||
|
# mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||||
|
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||||
|
# )
|
||||||
|
|
||||||
|
# outgoingMsg = Dict(
|
||||||
|
# :msgMeta=> msgMeta,
|
||||||
|
# :payload=> Dict(
|
||||||
|
# :text=> prompt,
|
||||||
|
# :kwargs=> Dict(
|
||||||
|
# :max_tokens=> 512,
|
||||||
|
# :stop=> ["<|eot_id|>"],
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
# for attempt in 1:5
|
||||||
|
# try
|
||||||
|
# response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
|
# _responseJsonStr = response[:response][:text]
|
||||||
|
# expectedJsonExample =
|
||||||
|
# """
|
||||||
|
# Here is an expected JSON format:
|
||||||
|
# {"evaluation": "...", "score": "..."}
|
||||||
|
# """
|
||||||
|
# responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||||
|
# evaluationDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
|
# # check if dict has all required value
|
||||||
|
# dummya::AbstractString = evaluationDict[:evaluation]
|
||||||
|
# dummyb::Integer = evaluationDict[:score]
|
||||||
|
|
||||||
|
# return (evaluationDict[:evaluation], evaluationDict[:score])
|
||||||
|
# catch e
|
||||||
|
# io = IOBuffer()
|
||||||
|
# showerror(io, e)
|
||||||
|
# errorMsg = String(take!(io))
|
||||||
|
# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||||
|
# println("")
|
||||||
|
# @warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||||
|
# println("")
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
# error("evaluator failed to generate an evaluation")
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -600,8 +737,9 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [WORKING] MCTS() for planning
|
- [x] MCTS() for planning
|
||||||
- [] add recap to initialState for earlier completed question
|
- [] add recap to initialState for earlier completed question
|
||||||
|
- [WORKING] conversation loop
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -617,36 +755,36 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
# add usermsg to a.chathistory
|
# add usermsg to a.chathistory
|
||||||
addNewMessage(a, "user", userinput[:text])
|
addNewMessage(a, "user", userinput[:text])
|
||||||
|
|
||||||
#[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned
|
currentstate =
|
||||||
if !isempty(a.plan[:currenttrajectory]) &&
|
if isempty(a.plan[:currenttrajectory])
|
||||||
a.plan[:currenttrajectory][end][:action] == "chatbox"
|
# set up initial state
|
||||||
|
Dict{Symbol, Any}(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else
|
|
||||||
initialState = Dict{Symbol, Any}(
|
|
||||||
|
|
||||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||||
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||||
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||||
:select=> nothing,
|
:userselect=> nothing,
|
||||||
:reward=> 0,
|
:reward=> 0,
|
||||||
:isterminal=> false,
|
:isterminal=> false,
|
||||||
:evaluation=> nothing,
|
:evaluation=> nothing,
|
||||||
:lesson=> nothing,
|
:lesson=> nothing,
|
||||||
|
:thoughtDict=> nothing,
|
||||||
|
:totalTrajectoryReward=> nothing,
|
||||||
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||||
# :recap=>,
|
# :recap=>,
|
||||||
:question=> userinput[:text],
|
:question=> userinput[:text],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
bestplan = runMCTS(a, initialState, decisionMaker, evaluator, reflector,
|
else
|
||||||
2, 3, 4, 1.0)
|
a.plan[:currenttrajectory]
|
||||||
error("---> bestplan")
|
|
||||||
|
|
||||||
# actor loop(bestplan)
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
||||||
|
totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||||
|
|
||||||
|
# transition
|
||||||
|
newstate = transition(a, bestNextState)
|
||||||
|
a.plan[:currenttrajectory] = newstate
|
||||||
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
module llmfunction
|
module llmfunction
|
||||||
|
|
||||||
export virtualWineCustomerChatbox, jsoncorrection, winestock,
|
export virtualWineUserChatbox, jsoncorrection, winestock,
|
||||||
virtualWineCustomerReccommendbox
|
virtualWineUserRecommendbox, userChatbox, userRecommendbox
|
||||||
|
|
||||||
using HTTP, JSON3, URIs, Random
|
using HTTP, JSON3, URIs, Random
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -26,8 +26,46 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function chatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
|
function userChatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
|
||||||
error("--> chatbox")
|
error("--> userChatbox")
|
||||||
|
|
||||||
|
# put in model format
|
||||||
|
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
|
||||||
|
llminfo = virtualWineCustomer[:llminfo]
|
||||||
|
formattedinput =
|
||||||
|
if llminfo[:name] == "llama3instruct"
|
||||||
|
formatLLMtext_llama3instruct("assistant", input)
|
||||||
|
else
|
||||||
|
error("llm model name is not defied yet $(@__LINE__)")
|
||||||
|
end
|
||||||
|
|
||||||
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
|
|
||||||
|
|
||||||
|
# return response
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
|
||||||
|
# Return
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docstring
|
||||||
|
- [PENDING] implement the function
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function userRecommendbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
|
||||||
|
error("--> userRecommendbox")
|
||||||
|
|
||||||
# put in model format
|
# put in model format
|
||||||
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
|
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
|
||||||
@@ -69,7 +107,7 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function virtualWineCustomerReccommendbox(a::T1, input
|
function virtualWineUserRecommendbox(a::T1, input
|
||||||
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent}
|
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent}
|
||||||
|
|
||||||
# put in model format
|
# put in model format
|
||||||
@@ -85,7 +123,7 @@ function virtualWineCustomerReccommendbox(a::T1, input
|
|||||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
virtualWineCustomer[:mqtttopic],
|
virtualWineCustomer[:mqtttopic],
|
||||||
senderName= "virtualWineCustomerReccommendbox",
|
senderName= "virtualWineUserRecommendbox",
|
||||||
senderId= a.id,
|
senderId= a.id,
|
||||||
receiverName= "virtualWineCustomer",
|
receiverName= "virtualWineCustomer",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||||
@@ -126,11 +164,10 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docs
|
- [] update docs
|
||||||
- [] add to remove <<< user option select >>> and <<| reward |>>
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function virtualWineCustomerChatbox(a::T1, input::T2
|
function virtualWineUserChatbox(a::T1, input::T2
|
||||||
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
|
||||||
|
|
||||||
# put in model format
|
# put in model format
|
||||||
@@ -146,7 +183,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
|
|||||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
virtualWineCustomer[:mqtttopic],
|
virtualWineCustomer[:mqtttopic],
|
||||||
senderName= "virtualWineCustomerChatbox",
|
senderName= "virtualWineUserChatbox",
|
||||||
senderId= a.id,
|
senderId= a.id,
|
||||||
receiverName= "virtualWineCustomer",
|
receiverName= "virtualWineCustomer",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||||
@@ -178,7 +215,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
|
|||||||
println("")
|
println("")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
error("virtualWineCustomerChatbox failed to get a response")
|
error("virtualWineUserChatbox failed to get a response")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
259
src/mcts.jl
259
src/mcts.jl
@@ -5,7 +5,8 @@
|
|||||||
|
|
||||||
module mcts
|
module mcts
|
||||||
|
|
||||||
export MCTSNode, runMCTS, isleaf
|
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
|
||||||
|
userChatbox
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
|
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -51,9 +52,9 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
|||||||
nodekey::T2
|
nodekey::T2
|
||||||
state::T1
|
state::T1
|
||||||
visits::Integer
|
visits::Integer
|
||||||
progressvalue::Number
|
progressvalue::Number # estimate value by LLM's reasoning
|
||||||
statevalue::Number
|
statevalue::Number # store discounted commulative reward (gather from its child node)
|
||||||
reward::Number
|
reward::Number # this node's own reward
|
||||||
isterminal::Bool
|
isterminal::Bool
|
||||||
parent::Union{MCTSNode, Nothing}
|
parent::Union{MCTSNode, Nothing}
|
||||||
children::Dict{String, MCTSNode}
|
children::Dict{String, MCTSNode}
|
||||||
@@ -132,23 +133,24 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||||
evaluator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
|
evaluator::Function, reflector::Function; totalsample::Integer=3
|
||||||
|
) where {T1<:agent}
|
||||||
|
|
||||||
nthSample = 0
|
nthSample = 0
|
||||||
while true
|
while true
|
||||||
nthSample += 1
|
nthSample += 1
|
||||||
if nthSample <= n
|
if nthSample <= totalsample
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
println("---> expand() sample $nthSample")
|
println("---> expand() sample $nthSample")
|
||||||
pprintln(node.state[:thoughtHistory])
|
pprintln(node.state[:thoughtHistory])
|
||||||
pprintln(thoughtDict)
|
pprintln(thoughtDict)
|
||||||
newNodeKey, newstate, reward, isterminalstate =
|
node.state[:thoughtDict] = thoughtDict
|
||||||
MCTStransition(a, node.state, thoughtDict)
|
newNodeKey, newstate = MCTStransition(a, node.state)
|
||||||
|
|
||||||
# add evaluator
|
# add evaluator
|
||||||
stateevaluation, progressvalue = evaluator(a, newstate)
|
stateevaluation, progressvalue = evaluator(a, newstate)
|
||||||
|
|
||||||
if reward < 0
|
if newstate[:reward] < 0
|
||||||
pprint(newstate[:thoughtHistory])
|
pprint(newstate[:thoughtHistory])
|
||||||
newstate[:evaluation] = stateevaluation
|
newstate[:evaluation] = stateevaluation
|
||||||
newstate[:lesson] = reflector(a, newstate)
|
newstate[:lesson] = reflector(a, newstate)
|
||||||
@@ -167,8 +169,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
end
|
end
|
||||||
|
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0,
|
node.children[newNodeKey] =
|
||||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
|
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
break
|
break
|
||||||
@@ -196,24 +199,30 @@ end
|
|||||||
julia>
|
julia>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||||
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent}
|
reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
|
||||||
|
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:agent}
|
||||||
|
|
||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
|
terminalstate = nothing
|
||||||
|
|
||||||
for depth in 1:maxDepth
|
for depth in 1:maxDepth
|
||||||
simTrajectoryReward += node.reward
|
simTrajectoryReward += node.reward
|
||||||
if node.isterminal
|
if node.isterminal
|
||||||
|
terminalstate = node.state
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
expand(a, node, decisionMaker, evaluator, reflector; n=n)
|
expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return simTrajectoryReward
|
return (simTrajectoryReward, terminalstate)
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Backpropagate reward along the simulation chain
|
""" Backpropagate reward along the simulation chain
|
||||||
@@ -285,20 +294,21 @@ julia> thoughtDict = Dict(
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function MCTStransition(a::T1, state::T2, thoughtDict::T3
|
function MCTStransition(a::T1, state::T2
|
||||||
)::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
|
thoughtDict = state[:thoughtDict]
|
||||||
actionname = thoughtDict[:action][:name]
|
actionname = thoughtDict[:action][:name]
|
||||||
actioninput = thoughtDict[:action][:input]
|
actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
# map action and input() to llm function
|
# map action and input() to llm function
|
||||||
response, select, reward, isterminal =
|
response, select, reward, isterminal =
|
||||||
if actionname == "chatbox"
|
if actionname == "chatbox"
|
||||||
virtualWineCustomerChatbox(a, actioninput) # virtual customer
|
virtualWineUserChatbox(a, actioninput) # virtual customer
|
||||||
elseif actionname == "winestock"
|
elseif actionname == "winestock"
|
||||||
winestock(a, actioninput)
|
winestock(a, actioninput)
|
||||||
elseif actionname == "recommendbox"
|
elseif actionname == "recommendbox"
|
||||||
virtualWineCustomerReccommendbox(a, actioninput)
|
virtualWineUserRecommendbox(a, actioninput)
|
||||||
else
|
else
|
||||||
error("undefined LLM function. Requesting $actionname")
|
error("undefined LLM function. Requesting $actionname")
|
||||||
end
|
end
|
||||||
@@ -321,7 +331,85 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
|
|||||||
|
|
||||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
|
|
||||||
return (newNodeKey, newstate, reward, isterminal)
|
return (newNodeKey, newstate)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Get a new state
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `a::T1`
|
||||||
|
one of YiemAgent's agent
|
||||||
|
- `state::T2`
|
||||||
|
current game state
|
||||||
|
- `thoughtDict::T3`
|
||||||
|
contain Thought, Action, Observation
|
||||||
|
- `isterminal::Function`
|
||||||
|
a function to determine terminal state
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||||
|
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
|
||||||
|
:storeinfo => Dict(),
|
||||||
|
:customerinfo => Dict()
|
||||||
|
)
|
||||||
|
julia> thoughtDict = Dict(
|
||||||
|
:question=> "I want to buy a bottle of wine.",
|
||||||
|
:thought_1=> "The customer wants to buy a bottle of wine.",
|
||||||
|
:action_1=> Dict{Symbol, Any}(
|
||||||
|
:name=>"Chatbox",
|
||||||
|
:input=>"What occasion are you buying the wine for?",
|
||||||
|
),
|
||||||
|
:observation_1 => ""
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [x] add other actions
|
||||||
|
- [] add embedding of newstate and store in newstate[:embedding]
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function transition(a::T1, state::T2
|
||||||
|
)::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
|
thoughtDict = state[:thoughtDict]
|
||||||
|
actionname = thoughtDict[:action][:name]
|
||||||
|
actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
|
# map action and input() to llm function
|
||||||
|
response, select, reward, isterminal =
|
||||||
|
if actionname == "chatbox"
|
||||||
|
userChatbox(a, actioninput) # virtual customer
|
||||||
|
elseif actionname == "winestock"
|
||||||
|
winestock(a, actioninput)
|
||||||
|
elseif actionname == "recommendbox"
|
||||||
|
userRecommendbox(a, actioninput)
|
||||||
|
else
|
||||||
|
error("undefined LLM function. Requesting $actionname")
|
||||||
|
end
|
||||||
|
|
||||||
|
latestThoughtKey, latestThoughtIndice =
|
||||||
|
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought")
|
||||||
|
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
||||||
|
latestThoughtKey = Symbol("thought_$nextIndice")
|
||||||
|
latestActionKey = Symbol("action_$nextIndice")
|
||||||
|
|
||||||
|
# add Thought, action, observation to thoughtHistory
|
||||||
|
newstate = deepcopy(state)
|
||||||
|
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
||||||
|
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
||||||
|
newObservationKey = Symbol("observation_$(nextIndice)")
|
||||||
|
newstate[:thoughtHistory][newObservationKey] = response
|
||||||
|
newstate[:reward] = reward
|
||||||
|
newstate[:select] = select
|
||||||
|
newstate[:isterminal] = isterminal
|
||||||
|
|
||||||
|
return newstate
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -396,6 +484,90 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node of a search tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `childNode::MCTSNode`
|
||||||
|
the highest value child node
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
- [TESTING] implement the function
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function selectBestNextState(node::MCTSNode)::MCTSNode
|
||||||
|
highestProgressValue = 0
|
||||||
|
nodekey = nothing
|
||||||
|
|
||||||
|
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
||||||
|
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
||||||
|
|
||||||
|
if stateValueSum != 0
|
||||||
|
for (k, childnode) in node.children
|
||||||
|
potential = childnode.statevalue / childnode.visits
|
||||||
|
|
||||||
|
if potential > highestProgressValue
|
||||||
|
highestProgressValue = potential
|
||||||
|
nodekey = childnode.nodekey
|
||||||
|
end
|
||||||
|
end
|
||||||
|
else
|
||||||
|
for (k, childnode) in node.children
|
||||||
|
potential = childnode.progressvalue + childnode.reward
|
||||||
|
|
||||||
|
if potential > highestProgressValue
|
||||||
|
highestProgressValue = potential
|
||||||
|
nodekey = childnode.nodekey
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return node.children[nodekey]
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node of a search tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `childNode::MCTSNode`
|
||||||
|
the highest value child node
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
- [TESTING] implement the function
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function selectBestTrajectory(node::MCTSNode)::MCTSNode
|
||||||
|
while !isleaf(node)
|
||||||
|
node = selectBestNextState(node)
|
||||||
|
end
|
||||||
|
|
||||||
|
return node
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
""" Determine wheter a given node is a root node
|
""" Determine wheter a given node is a root node
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -451,7 +623,7 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[PENDING] return best plan
|
[x] return best action
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -460,50 +632,49 @@ function runMCTS(
|
|||||||
initialState,
|
initialState,
|
||||||
decisionMaker::Function,
|
decisionMaker::Function,
|
||||||
evaluator::Function,
|
evaluator::Function,
|
||||||
reflector::Function,
|
reflector::Function;
|
||||||
n::Integer,
|
totalsample::Integer=3,
|
||||||
maxDepth::Integer,
|
maxDepth::Integer=3,
|
||||||
maxIterations::Integer,
|
maxiterations::Integer=10,
|
||||||
w::Float64
|
explorationweight::Number=1.0,
|
||||||
) where {T1<:agent}
|
) where {T1<:agent}
|
||||||
|
|
||||||
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
for nth in 1:maxIterations
|
for nth in 1:maxiterations
|
||||||
node = root
|
node = root
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
|
|
||||||
while !isleaf(node)
|
while !isleaf(node)
|
||||||
node = UCTselect(node, w)
|
node = UCTselect(node, explorationweight)
|
||||||
end
|
end
|
||||||
if node.isterminal
|
if node.isterminal
|
||||||
# MCTS arrive at the leaf node that is also a terminal state,
|
# MCTS arrive at the leaf node that is also a terminal state,
|
||||||
# do nothing then go directly to backpropagation
|
# do nothing then go directly to backpropagation
|
||||||
backpropagate(leafNode, node.reward)
|
backpropagate(leafNode, node.reward)
|
||||||
else
|
else
|
||||||
expand(a, node, decisionMaker, evaluator, reflector; n=n)
|
expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
|
||||||
leafNode = selectChildNode(node)
|
leafNode = selectChildNode(node)
|
||||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, evaluator,
|
simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker, evaluator,
|
||||||
reflector; maxDepth=maxDepth, n=n)
|
reflector; maxDepth=maxDepth, totalsample=totalsample)
|
||||||
|
if terminalstate !== nothing
|
||||||
|
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||||
|
end
|
||||||
|
|
||||||
|
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||||
|
# open("trajectory.json", "w") do io
|
||||||
|
# JSON3.pretty(io, terminalstate)
|
||||||
|
# end
|
||||||
|
|
||||||
backpropagate(leafNode, simTrajectoryReward)
|
backpropagate(leafNode, simTrajectoryReward)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
avgStateValue = 0
|
bestNextState = selectBestNextState(root)
|
||||||
selectedChildKey = nothing
|
besttrajectory = selectBestTrajectory(root)
|
||||||
for (k, v) in root.children
|
|
||||||
k_avgStateValue = v.statevalue / v.visits
|
return (bestNextState.state, besttrajectory.state)
|
||||||
if k_avgStateValue > avgStateValue
|
|
||||||
avgStateValue = k_avgStateValue
|
|
||||||
selectedChildKey = k
|
|
||||||
end
|
end
|
||||||
end
|
|
||||||
|
|
||||||
return root.children[selectedChildKey]
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
:text=> "I like it dry.",
|
:text=> "I like dry wine with fruity flavors.",
|
||||||
:select=> nothing,
|
:select=> nothing,
|
||||||
:reward=> 0,
|
:reward=> 0,
|
||||||
:isterminal=> false,
|
:isterminal=> false,
|
||||||
|
|||||||
Reference in New Issue
Block a user