From 43e7ba399189eb17412553f2035f0cceff4d43b4 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Tue, 7 May 2024 06:30:24 +0700 Subject: [PATCH] update --- src/interface.jl | 69 ++++++++++++++++++++++++---------------------- src/llmfunction.jl | 10 +++++++--- src/mcts.jl | 15 +++++----- test/test_1.jl | 11 ++++++-- 4 files changed, 60 insertions(+), 45 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6543e89..ed7579d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -318,46 +318,46 @@ function reflector() end -""" Determine whether the state is a terminal state +# """ Determine whether the state is a terminal state -# Arguments - - `state::T` - a game state +# # Arguments +# - `state::T` +# a game state -# Return - - `(isterminalstate, reward)::Tuple{Bool, <:Number}` +# # Return +# - `(isterminalstate, reward)::Tuple{Bool, <:Number}` -# Example -```jldoctest -julia> -``` +# # Example +# ```jldoctest +# julia> +# ``` -# TODO - [PENDING] add Reflect() +# # TODO +# [PENDING] add Reflect() -# Signature -""" -function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict} - latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation") - latestObservation = state[:thoughtHistory][latestObservationKey] +# # Signature +# """ +# function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict} +# latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "observation") +# latestObservation = state[:thoughtHistory][latestObservationKey] - if latestObservation !== nothing +# if latestObservation !== nothing - # terminal condition is when the user select wine by putting 
<> in latest observation +# if occursin("<<", latestObservation) && occursin(">>", latestObservation) +# isterminalstate = true +# reward = 1 +# else +# isterminalstate = false +# reward = 0 +# end +# else +# isterminalstate = false +# reward = 0 +# end - return (isterminalstate, reward) -end +# return (isterminalstate, reward) +# end """ Chat with llm. @@ -436,7 +436,10 @@ function conversation(a::T, userinput::Dict) where {T<:agent} # deepcopy the info to prevent modifying the info unintentionally during MCTS planning :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), - :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), + :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), + :select=> nothing, + :reward=> 0, + :isterminal=> false, :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :question=> userinput[:text], ) diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 3745505..349e08a 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -125,6 +125,10 @@ end julia> ``` +# TODO + - [] update docs (now returns a (text, select, reward, isterminal) tuple) + - [] add to remove <<< user option select >>> and <<| reward |>> + # Signature """ -function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} +function virtualWineCustomerChatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} @@ -158,9 +162,9 @@ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, ) @show outgoingMsg result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) - response = result[:response][:text] + response = result[:response] - return response + return (response[:text], response[:select], response[:reward], response[:isterminal]) end diff --git a/src/mcts.jl b/src/mcts.jl index 3648ce9..1d3d618 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -137,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, try thoughtDict = decisionMaker(a, node.state) - newNodeKey, newstate, isterminalstate, reward = + newNodeKey, newstate, reward, isterminalstate = MCTStransition(a, 
node.state, thoughtDict, isterminal) # add progressValueEstimator @@ -181,11 +181,10 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim simTrajectoryReward = 0.0 for depth in 1:maxDepth - if node.isterminal - simTrajectoryReward += node.reward + simTrajectoryReward += node.reward + if node.isterminal break else - simTrajectoryReward += node.reward expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) node = selectChildNode(node) end @@ -268,7 +267,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function actioninput = thoughtDict[:action][:input] # map action and input() to llm function - response = + response, select, reward, isterminalstate = if actionname == "chatbox" virtualWineCustomerChatbox(a, actioninput) # virtual customer elseif actionname == "winestock" @@ -291,11 +290,13 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] latestObservationKey = Symbol("observation_$(nextIndice)") newstate[:thoughtHistory][latestObservationKey] = response + newstate[:reward] = reward + newstate[:select] = select + newstate[:isterminal] = isterminalstate newNodeKey = GeneralUtils.uuid4snakecase() - isterminalstate, reward = isterminal(newstate) - return (newNodeKey, newstate, isterminalstate, reward) + return (newNodeKey, newstate, reward, isterminalstate) end diff --git a/test/test_1.jl b/test/test_1.jl index c8165bb..c4da5ec 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -28,6 +28,9 @@ outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( :text=> "We are holding a wedding party", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -45,6 +48,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) + outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( @@ -56,7 +60,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) - outgoingMsg = Dict( 
:msgMeta=> msgMeta, :payload=> Dict( @@ -76,16 +79,19 @@ outgoingMsg = Dict( result = GeneralUtils.sendMqttMsg(outgoingMsg) + + outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "<>", + :text=> "OK, I'll take it.", ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) + outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( @@ -96,6 +102,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) + outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict(