update
This commit is contained in:
@@ -318,46 +318,46 @@ function reflector()
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
""" Determine whether the state is a terminal state
|
# """ Determine whether the state is a terminal state
|
||||||
|
|
||||||
# Arguments
|
# # Arguments
|
||||||
- `state::T`
|
# - `state::T`
|
||||||
a game state
|
# a game state
|
||||||
|
|
||||||
# Return
|
# # Return
|
||||||
- `(isterminalstate, reward)::Tuple{Bool, <:Number}`
|
# - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
|
||||||
|
|
||||||
# Example
|
# # Example
|
||||||
```jldoctest
|
# ```jldoctest
|
||||||
julia>
|
# julia>
|
||||||
```
|
# ```
|
||||||
|
|
||||||
# TODO
|
# # TODO
|
||||||
[PENDING] add Reflect()
|
# [PENDING] add Reflect()
|
||||||
|
|
||||||
# Signature
|
# # Signature
|
||||||
"""
|
# """
|
||||||
function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
|
# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
|
||||||
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
|
# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
|
||||||
latestObservation = state[:thoughtHistory][latestObservationKey]
|
# latestObservation = state[:thoughtHistory][latestObservationKey]
|
||||||
|
|
||||||
if latestObservation !== nothing
|
# if latestObservation !== nothing
|
||||||
|
|
||||||
# terminal condition is when the user select wine by putting <<winename>> in latest observation
|
# # terminal condition is when the user select wine by putting <<winename>> in latest observation
|
||||||
if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
# if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
||||||
isterminalstate = true
|
# isterminalstate = true
|
||||||
reward = 1
|
# reward = 1
|
||||||
else
|
# else
|
||||||
isterminalstate = false
|
# isterminalstate = false
|
||||||
reward = 0
|
# reward = 0
|
||||||
end
|
# end
|
||||||
else
|
# else
|
||||||
isterminalstate = false
|
# isterminalstate = false
|
||||||
reward = 0
|
# reward = 0
|
||||||
end
|
# end
|
||||||
|
|
||||||
return (isterminalstate, reward)
|
# return (isterminalstate, reward)
|
||||||
end
|
# end
|
||||||
|
|
||||||
|
|
||||||
""" Chat with llm.
|
""" Chat with llm.
|
||||||
@@ -436,7 +436,10 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
|
|
||||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||||
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||||
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||||
|
:select=> nothing,
|
||||||
|
:reward=> 0,
|
||||||
|
:isterminal=> false,
|
||||||
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||||
:question=> userinput[:text],
|
:question=> userinput[:text],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -125,6 +125,10 @@ end
|
|||||||
julia>
|
julia>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
- [] add to remove <<< user option select >>> and <<| reward |>>
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString}
|
function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString}
|
||||||
@@ -158,9 +162,9 @@ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent,
|
|||||||
)
|
)
|
||||||
@show outgoingMsg
|
@show outgoingMsg
|
||||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||||
response = result[:response][:text]
|
response = result[:response]
|
||||||
|
|
||||||
return response
|
return (response[:text], response[:select], response[:reward], response[:isterminal])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
15
src/mcts.jl
15
src/mcts.jl
@@ -137,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
try
|
try
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
|
|
||||||
newNodeKey, newstate, isterminalstate, reward =
|
newNodeKey, newstate, reward, isterminalstate =
|
||||||
MCTStransition(a, node.state, thoughtDict, isterminal)
|
MCTStransition(a, node.state, thoughtDict, isterminal)
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
@@ -181,11 +181,10 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
|
|||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
|
|
||||||
for depth in 1:maxDepth
|
for depth in 1:maxDepth
|
||||||
if node.isterminal
|
simTrajectoryReward += node.reward
|
||||||
simTrajectoryReward += node.reward
|
if node.isterminalrd
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
simTrajectoryReward += node.reward
|
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
end
|
end
|
||||||
@@ -268,7 +267,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
|||||||
actioninput = thoughtDict[:action][:input]
|
actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
# map action and input() to llm function
|
# map action and input() to llm function
|
||||||
response =
|
response, select, reward, isterminal =
|
||||||
if actionname == "chatbox"
|
if actionname == "chatbox"
|
||||||
virtualWineCustomerChatbox(a, actioninput) # virtual customer
|
virtualWineCustomerChatbox(a, actioninput) # virtual customer
|
||||||
elseif actionname == "winestock"
|
elseif actionname == "winestock"
|
||||||
@@ -291,11 +290,13 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
|||||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
||||||
latestObservationKey = Symbol("observation_$(nextIndice)")
|
latestObservationKey = Symbol("observation_$(nextIndice)")
|
||||||
newstate[:thoughtHistory][latestObservationKey] = response
|
newstate[:thoughtHistory][latestObservationKey] = response
|
||||||
|
newstate[:reward] = reward
|
||||||
|
newstate[:select] = select
|
||||||
|
newstate[:isterminal] = isterminal
|
||||||
|
|
||||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
isterminalstate, reward = isterminal(newstate)
|
|
||||||
|
|
||||||
return (newNodeKey, newstate, isterminalstate, reward)
|
return (newNodeKey, newstate, reward, isterminal)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,9 @@ outgoingMsg = Dict(
|
|||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
:text=> "We are holding a wedding party",
|
:text=> "We are holding a wedding party",
|
||||||
|
:select=> nothing,
|
||||||
|
:reward=> 0,
|
||||||
|
:isterminal=> false,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||||
@@ -45,6 +48,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
@@ -56,7 +60,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
@@ -76,16 +79,19 @@ outgoingMsg = Dict(
|
|||||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
:text=> "<<OK, I'll take it.>>",
|
:text=> "OK, I'll take it.",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
@@ -96,6 +102,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
|
|||||||
Reference in New Issue
Block a user