This commit is contained in:
narawat lamaiin
2024-05-07 06:30:24 +07:00
parent 8cc5606ae8
commit 43e7ba3991
4 changed files with 59 additions and 44 deletions

View File

@@ -318,46 +318,46 @@ function reflector()
end end
""" Determine whether the state is a terminal state # """ Determine whether the state is a terminal state
# Arguments # # Arguments
- `state::T` # - `state::T`
a game state # a game state
# Return # # Return
- `(isterminalstate, reward)::Tuple{Bool, <:Number}` # - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
# Example # # Example
```jldoctest # ```jldoctest
julia> # julia>
``` # ```
# TODO # # TODO
[PENDING] add Reflect() # [PENDING] add Reflect()
# Signature # # Signature
""" # """
function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict} # function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation") # latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation")
latestObservation = state[:thoughtHistory][latestObservationKey] # latestObservation = state[:thoughtHistory][latestObservationKey]
if latestObservation !== nothing # if latestObservation !== nothing
# terminal condition is when the user select wine by putting <<winename>> in latest observation # # terminal condition is when the user select wine by putting <<winename>> in latest observation
if occursin("<<", latestObservation) && occursin(">>", latestObservation) # if occursin("<<", latestObservation) && occursin(">>", latestObservation)
isterminalstate = true # isterminalstate = true
reward = 1 # reward = 1
else # else
isterminalstate = false # isterminalstate = false
reward = 0 # reward = 0
end # end
else # else
isterminalstate = false # isterminalstate = false
reward = 0 # reward = 0
end # end
return (isterminalstate, reward) # return (isterminalstate, reward)
end # end
""" Chat with llm. """ Chat with llm.
@@ -437,6 +437,9 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:select=> nothing,
:reward=> 0,
:isterminal=> false,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
:question=> userinput[:text], :question=> userinput[:text],
) )

View File

@@ -125,6 +125,10 @@ end
julia> julia>
``` ```
# TODO
- [ ] update docs
- [ ] add logic to remove the <<< user option select >>> and <<| reward |>> markers
# Signature # Signature
""" """
function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString}
@@ -158,9 +162,9 @@ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent,
) )
@show outgoingMsg @show outgoingMsg
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
response = result[:response][:text] response = result[:response]
return response return (response[:text], response[:select], response[:reward], response[:isterminal])
end end

View File

@@ -137,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
try try
thoughtDict = decisionMaker(a, node.state) thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, isterminalstate, reward = newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict, isterminal) MCTStransition(a, node.state, thoughtDict, isterminal)
# add progressValueEstimator # add progressValueEstimator
@@ -181,11 +181,10 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
for depth in 1:maxDepth for depth in 1:maxDepth
if node.isterminal simTrajectoryReward += node.reward
simTrajectoryReward += node.reward if node.isterminal
break break
else else
simTrajectoryReward += node.reward
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
node = selectChildNode(node) node = selectChildNode(node)
end end
@@ -268,7 +267,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
actioninput = thoughtDict[:action][:input] actioninput = thoughtDict[:action][:input]
# map action and input() to llm function # map action and input() to llm function
response = response, select, reward, isterminal =
if actionname == "chatbox" if actionname == "chatbox"
virtualWineCustomerChatbox(a, actioninput) # virtual customer virtualWineCustomerChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock" elseif actionname == "winestock"
@@ -291,11 +290,13 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
latestObservationKey = Symbol("observation_$(nextIndice)") latestObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response newstate[:thoughtHistory][latestObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
isterminalstate, reward = isterminal(newstate)
return (newNodeKey, newstate, isterminalstate, reward) return (newNodeKey, newstate, reward, isterminal)
end end

View File

@@ -28,6 +28,9 @@ outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "We are holding a wedding party", :text=> "We are holding a wedding party",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -45,6 +48,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
@@ -56,7 +60,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
@@ -76,16 +79,19 @@ outgoingMsg = Dict(
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "<<OK, I'll take it.>>", :text=> "OK, I'll take it.",
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
@@ -96,6 +102,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(