From ebacc1692283f3592c1200603aa260d49633a196 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Mon, 6 May 2024 16:13:01 +0700 Subject: [PATCH] Add timeout=120 to MQTT requests, adjust runMCTS arguments, extend test messages --- src/interface.jl | 2 +- src/llmfunction.jl | 6 +++--- src/mcts.jl | 11 ++++++----- test/test_1.jl | 19 ++++++++++++++++++- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6682a5c..6543e89 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -442,7 +442,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} ) ) bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, - isterminal, 2, 3, 2, 1.0) + isterminal, 2, 3, 3, 1.0) error("---> bestplan") # actor loop(bestplan) diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 6ad0542..3745505 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -101,7 +101,7 @@ function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:a ) ) @show outgoingMsg - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) response = result[:response][:text] return response @@ -157,7 +157,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, ) ) @show outgoingMsg - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) response = result[:response][:text] return response @@ -291,7 +291,7 @@ function jsoncorrection(a::T1, input::T2, ) ) ) - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) incorrectjson = result[:response][:text] end else diff --git a/src/mcts.jl b/src/mcts.jl index 680d87c..6f68fa1 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -348,8 +348,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children + 
thisNodeProgressValue = childNode.statevalue + childNode.reward if childNode.statevalue > highestProgressValue - highestProgressValue = childNode.statevalue + childNode.reward + highestProgressValue = thisNodeProgressValue nodekey = childNode.nodekey end end @@ -443,10 +444,10 @@ function runMCTS( expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) - leaf_node = UCTselect(node, w) - simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator, - isterminal, maxDepth, n=n) - backpropagate(leaf_node, simTrajectoryReward) + leafNode = UCTselect(node, w) + simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, + isterminal, maxDepth, n=n) + backpropagate(leafNode, simTrajectoryReward) end best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) diff --git a/test/test_1.jl b/test/test_1.jl index 413bb74..c8165bb 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -96,7 +96,24 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) - +outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> "You did not gave me any choice.", + ) +) +result = GeneralUtils.sendMqttMsg(outgoingMsg) + + + + +outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> "Yes.", + ) +) +result = GeneralUtils.sendMqttMsg(outgoingMsg)