From 2e9c21f243233861cecba00d396884cfe2a7a5c3 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Sat, 11 May 2024 15:34:51 +0700 Subject: [PATCH] update --- src/MCTSexamplePrompt.py | 2 +- src/interface.jl | 252 ++++++++++++++++----------------- src/llmfunction.jl | 147 +++++++++++------------ src/mcts.jl | 83 ++++++++----- test/test_1.jl | 8 +- 5 files changed, 226 insertions(+), 266 deletions(-) diff --git a/src/MCTSexamplePrompt.py b/src/MCTSexamplePrompt.py index a4bbc47..2a5f7ab 100644 --- a/src/MCTSexamplePrompt.py +++ b/src/MCTSexamplePrompt.py @@ -416,7 +416,7 @@ Action 3: Search[Mexican Grand Prix winners] Observation 3: Could not find Mexican Grand Prix winners. Similar: ['Mexican Grand Prix', 'List of Formula One Grand Prix winners', '1990 Mexican Grand Prix', '2018 Mexican Grand Prix', '2019 Mexican Grand Prix']. Thought 4: Given the difficulties in finding a direct list of Mexican Grand Prix winners, I should search for the Mexican Grand Prix to get a broader picture of the race's history. This might include winners. Action 4: Search[Mexican Grand Prix] -This trajectory is incorrect as my search should be related to Mexican Formula One race car drivers, not winners of the Mexican Grand Prix, a seperate event. A better search would have been for the List of Formula One Grand Prix winners, as suggested. +This trajectory is incorrect as my search should be related to Mexican Formula One race car drivers, not winners of the Mexican Grand Prix, a separate event. A better search would have been for the List of Formula One Grand Prix winners, as suggested. Thus the correctness score is 3 Question: Which magazine was started first Arthur's Magazine or First for Women? 
diff --git a/src/interface.jl b/src/interface.jl index 6f384fa..cebd781 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -126,7 +126,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 1) Get to know what occasion the user is buying wine for 2) Get to know what food the user will have with wine 3) Get to know how much the user willing to spend - 4) Get to know type of wine the user is looking for e.g. Red, White, Sparkling, Rose, Dessert, Fortified + 4) Get to know type of wine the user is looking for e.g. red, white, sparkling, rose, dessert, fortified 5) Get to know what characteristics of wine the user is looking for e.g. tannin, sweetness, intensity, acidity 6) Check your inventory for the best wine that match the user preference @@ -197,51 +197,49 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 ) ) @show outgoingMsg - attempt = 0 - while true - attempt += 1 - if attempt <= 5 - try - response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) - _responseJsonStr = response[:response][:text] - expectedJsonExample = - """ - Here is an expected JSON format: - { - "thought": "...", - "action": {"name": "...", "input": "..."}, - "observation": "..." - } - """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) - thoughtDict = copy(JSON3.read(responseJsonStr)) - # check if dict has all required value - dummya::AbstractString = thoughtDict[:thought] - actionname::AbstractString = thoughtDict[:action][:name] - actioninput::AbstractString = thoughtDict[:action][:input] + for attempt in 1:5 + try + response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + _responseJsonStr = response[:response][:text] + expectedJsonExample = + """ + Here is an expected JSON format: + { + "thought": "...", + "action": {"name": "...", "input": "..."}, + "observation": "..." 
+ } + """ + responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + thoughtDict = copy(JSON3.read(responseJsonStr)) - if actionname ∈ ["winestock", "chatbox", "recommendbox"] - # LLM use available function - else - error("DecisionMaker use wrong function") - end - - return thoughtDict - catch e - io = IOBuffer() - showerror(io, e) - errorMsg = String(take!(io)) - st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) - println("") - @warn "Error occurred: $errorMsg\n$st" - println("") + # check if dict has all required value + thought::AbstractString = thoughtDict[:thought] + actionname::AbstractString = thoughtDict[:action][:name] + actioninput::AbstractString = thoughtDict[:action][:input] + if actionname ∈ ["winestock", "chatbox", "recommendbox"] + # LLM use available function + elseif thought == "" + error("DecisionMaker has no thought") + elseif length(actioninput) == 0 + error("DecisionMaker has no actioninput") + else + error("DecisionMaker use wrong function") end - else - error("DecisionMaker failed to generate a thought") + return thoughtDict + catch e + io = IOBuffer() + showerror(io, e) + errorMsg = String(take!(io)) + st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) + println("") + @warn "Attempt $attempt. Error occurred: $errorMsg\n$st" + println("") end end + error("DecisionMaker failed to generate a thought") end @@ -297,7 +295,22 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where } {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question. 
It is also better to have simple searches corresponding to a single entity, making this the best action.", - "score": 10} + "score": 10 + } + + { + "question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?", + "thought_1": "Let me check our inventory first to see if I have it.", + "action_1": {"name": "inventory", "input": "pen with 4 colors and a pencil."}, + "observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}", + "thought_2": "Ok, I have what the user is asking. Let's tell the user.", + "action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"}, + "observation_2": "This is not what I wanted." + } + {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it, + not a pen and a pencil separately. A better search term should have been a 4-colors pen with a pencil, all-in-one.", + "score": 2 + } Let's begin!: $(JSON3.write(state[:thoughtHistory])) @@ -334,40 +347,34 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where ) ) - attempt = 0 - while true - attempt += 1 - if attempt <= 5 - try - response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) - _responseJsonStr = response[:response][:text] - expectedJsonExample = - """ - Here is an expected JSON format: - {"evaluation": "...", "score": "..."} - """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) - evaluationDict = copy(JSON3.read(responseJsonStr)) + for attempt in 1:5 + try + response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + _responseJsonStr = response[:response][:text] + expectedJsonExample = + """ + Here is an expected JSON format: + {"evaluation": "...", "score": "..."} + """ + responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + evaluationDict = copy(JSON3.read(responseJsonStr)) - # check if dict has all required value - dummya::AbstractString = 
evaluationDict[:evaluation] - dummyb::Integer = evaluationDict[:score] + # check if dict has all required value + dummya::AbstractString = evaluationDict[:evaluation] + dummyb::Integer = evaluationDict[:score] - return (evaluationDict[:evaluation], evaluationDict[:score]) - catch e - io = IOBuffer() - showerror(io, e) - errorMsg = String(take!(io)) - st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) - println("") - @warn "Error occurred: $errorMsg\n$st" - println("") - end - - else - error("progressValueEstimator failed to generate an evaluation") + return (evaluationDict[:evaluation], evaluationDict[:score]) + catch e + io = IOBuffer() + showerror(io, e) + errorMsg = String(take!(io)) + st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) + println("") + @warn "Attempt $attempt. Error occurred: $errorMsg\n$st" + println("") end end + error("progressValueEstimator failed to generate an evaluation") end @@ -392,55 +399,6 @@ julia> function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} # https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py - # _prompt = - # """ - # You are an advanced reasoning agent that can improve based on self refection. - # You will be given the information related to the previous help you've done for a user. - # You were unsuccessful in helping the user either because you guessed the wrong answer with Finish[answer], or you used up your set number of reasoning steps. - # In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. - # Use complete sentences. - - # You should only respond in JSON format as describe below: - # {"reflection": "your relection"} - - # Here are some examples: - # Previous Trial: - # { - # "question": "Hello, I would like a get a bottle of wine", - # "thought_1": "A customer wants to buy a bottle of wine. 
Before making a recommendation, I need to know more about their preferences.", - # "action_1": {"name": "chatbox", "input": "What is the occasion for which you're buying this wine?"}, - # "observation_1": "We are holding a wedding party", - - # "thought_2": "A wedding party, that's a great occasion! The customer might be looking for a celebratory drink. Let me ask some more questions to narrow down the options.", - # "action_2": {"name": "chatbox", "input": "What type of food will you be serving at the wedding?"}, - # "observation_2": "It will be Thai dishes.", - - # "thought_3": "With Thai food, I should recommend a wine that complements its spicy and savory flavors. And since it's a celebratory occasion, the customer might prefer a full-bodied wine.", - # "action_3": {"name": "chatbox", "input": "What is your budget for this bottle of wine?"}, - # "observation_3": "I would spend up to 50 bucks.", - - # "thought_4": "Now that I have some more information, it's time to narrow down the options.", - # "action_4": {"name": "winestock", "input": "red wine with full body, pairs well with spicy food, budget \$50"}, - # "observation_4": "I found the following wines in our stock: \n{\n 1: El Enemigo Cabernet Franc 2019\n2: Tantara Chardonnay 2017\n\n}\n", - - # "thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.", - # "action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"}, - # "observation_5": "I like full-bodied Red wine with low tannin.", - - # "thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.", - # "action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"}, - # "observation_6": "I don't like the one you recommend. I want dry wine." 
- # } - - # { - # "reflection": "I asked the user about the occasion, food type, and budget, and then searched for wine in the inventory right away. However, I should have asked the user for the specific wine type and their preferences in order to gather more information before making a recommendation." - # } - - # Previous trial: - # $(JSON3.write(state[:thoughtHistory])) - # {"reflection" - # """ - _prompt = """ You are a helpful sommelier working for a wine store. @@ -475,7 +433,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} "thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.", "action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"}, - "observation_5": "I like full-bodied Red wine with low tannin.", + "observation_5": "I like full-bodied red wine with low tannin.", "thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.", "action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"}, @@ -523,39 +481,33 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} ) ) - attempt = 0 - while true - attempt += 1 - if attempt <= 5 - try - response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) - _responseJsonStr = response[:response][:text] - expectedJsonExample = - """ - Here is an expected JSON format: - {"reflection": "..."} - """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) - reflectionDict = copy(JSON3.read(responseJsonStr)) + for attempt in 1:5 + try + response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) + _responseJsonStr = response[:response][:text] + expectedJsonExample = + """ + Here is an expected JSON format: + {"reflection": "..."} + """ + responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + 
reflectionDict = copy(JSON3.read(responseJsonStr)) - # check if dict has all required value - dummya::AbstractString = reflectionDict[:reflection] + # check if dict has all required value + dummya::AbstractString = reflectionDict[:reflection] - return reflectionDict[:reflection] - catch e - io = IOBuffer() - showerror(io, e) - errorMsg = String(take!(io)) - st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) - println("") - @warn "Error occurred: $errorMsg\n$st" - println("") - end - - else - error("reflector failed to generate a thought") + return reflectionDict[:reflection] + catch e + io = IOBuffer() + showerror(io, e) + errorMsg = String(take!(io)) + st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) + println("") + @warn "Attempt $attempt. Error occurred: $errorMsg\n$st" + println("") end end + error("reflector failed to generate a thought") end diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 17c9f10..7b9bfd3 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -99,7 +99,7 @@ function virtualWineCustomerReccommendbox(a::T1, input :text=> prompt, ) ) - @show outgoingMsg + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) response = result[:response] @@ -162,28 +162,23 @@ function virtualWineCustomerChatbox(a::T1, input::T2 ) attempt = 0 - while true - attempt += 1 - if attempt <= 5 - try - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) - response = result[:response] + for attempt in 1:5 + try + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) + response = result[:response] - return (response[:text], response[:select], response[:reward], response[:isterminal]) - catch e - io = IOBuffer() - showerror(io, e) - errorMsg = String(take!(io)) - st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) - println("") - @warn "Error occurred: $errorMsg\n$st" - println("") - end - - else - 
error("virtualWineCustomerChatbox failed to get a response") + return (response[:text], response[:select], response[:reward], response[:isterminal]) + catch e + io = IOBuffer() + showerror(io, e) + errorMsg = String(take!(io)) + st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) + println("") + @warn "Error occurred: $errorMsg\n$st" + println("") end end + error("virtualWineCustomerChatbox failed to get a response") end @@ -252,77 +247,69 @@ julia> function jsoncorrection(a::T1, input::T2, correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} - attemptround = 0 incorrectjson = deepcopy(input) correctjson = nothing - while true - attemptround += 1 - if attemptround <= 5 - try - d = copy(JSON3.read(incorrectjson)) - correctjson = incorrectjson - break - catch e - @warn "Attempting correct JSON string. $attemptround" - e = """$e""" - if occursin("EOF", e) - e = split(e, "EOF")[1] * "EOF" - end - incorrectjson = deepcopy(input) - _prompt = - """ - Your goal are: - 1) Use the info why the given JSON string failed to load and provide a corrected version that can be loaded by Python's json.load function. - 2) The user need Corrected JSON string only. Do not provide any other info. - $correctJsonExample + for attempt in 1:5 + try + d = copy(JSON3.read(incorrectjson)) + correctjson = incorrectjson + return correctjson + catch e + @warn "Attempting correct JSON string. Attempt $attempt" + e = """$e""" + if occursin("EOF", e) + e = split(e, "EOF")[1] * "EOF" + end + incorrectjson = deepcopy(input) + _prompt = + """ + Your goal are: + 1) Use the info why the given JSON string failed to load and provide a corrected version that can be loaded by Python's json.load function. + 2) The user need Corrected JSON string only. Do not provide any other info. - Let's begin! 
- Given JSON string: $incorrectjson - The given JSON string failed to load previously because: $e - Corrected JSON string: - """ + $correctJsonExample - # apply LLM specific instruct format - externalService = a.config[:externalservice][:text2textinstruct] - llminfo = externalService[:llminfo] - prompt = - if llminfo[:name] == "llama3instruct" - formatLLMtext_llama3instruct("system", _prompt) - else - error("llm model name is not defied yet $(@__LINE__)") - end + Let's begin! + Given JSON string: $incorrectjson + The given JSON string failed to load previously because: $e + Corrected JSON string: + """ - # send formatted input to user using GeneralUtils.sendReceiveMqttMsg - msgMeta = GeneralUtils.generate_msgMeta( - externalService[:mqtttopic], - senderName= "jsoncorrection", - senderId= a.id, - receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], - ) + # apply LLM specific instruct format + externalService = a.config[:externalservice][:text2textinstruct] + llminfo = externalService[:llminfo] + prompt = + if llminfo[:name] == "llama3instruct" + formatLLMtext_llama3instruct("system", _prompt) + else + error("llm model name is not defied yet $(@__LINE__)") + end - outgoingMsg = Dict( - :msgMeta=> msgMeta, - :payload=> Dict( - :text=> prompt, - :kwargs=> Dict( - :max_tokens=> 512, - :stop=> ["<|eot_id|>"], - ) + # send formatted input to user using GeneralUtils.sendReceiveMqttMsg + msgMeta = GeneralUtils.generate_msgMeta( + externalService[:mqtttopic], + senderName= "jsoncorrection", + senderId= a.id, + receiverName= "text2textinstruct", + mqttBroker= a.config[:mqttServerInfo][:broker], + mqttBrokerPort= a.config[:mqttServerInfo][:port], + ) + + outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> prompt, + :kwargs=> Dict( + :max_tokens=> 512, + :stop=> ["<|eot_id|>"], ) ) - result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) - incorrectjson = 
result[:response][:text] - end - else - error("Can't fix JSON string") - break + ) + result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) + incorrectjson = result[:response][:text] end end - - return correctjson end diff --git a/src/mcts.jl b/src/mcts.jl index 95c1d8d..1eb44f5 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -51,6 +51,7 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} nodekey::T2 state::T1 visits::Integer + progressvalue::Number statevalue::Number reward::Number isterminal::Bool @@ -78,26 +79,47 @@ julia> # Signature """ function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} - max_uct = -Inf - selectedNode = nothing + maxUCT = -Inf + selectedNode = nothing - for (childState, childNode) in node.children - weightedterm = - if node.visits == 0 || childNode.visits == 0 - 0 - else - w * sqrt(log(node.visits) / childNode.visits) - end - uctValue = childNode.statevalue + weightedterm - - if uctValue > max_uct - max_uct = uctValue - selectedNode = childNode - end + for (childState, childNode) in node.children + UCTvalue = + if childNode.visits != 0 + weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term + childNode.statevalue + weightedterm + else # node.visits == 0 makes sqrt() in explore term error + childNode.progressvalue # exploit term end + + if UCTvalue > maxUCT + maxUCT = UCTvalue + selectedNode = childNode + end + end - return selectedNode + return selectedNode end +# function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} +# max_uct = -Inf +# selectedNode = nothing + +# for (childState, childNode) in node.children +# weightedterm = +# if node.visits == 0 || childNode.visits == 0 # node.visits == 0 makes sqrt() error +# 0 +# else +# w * sqrt(log(node.visits) / childNode.visits) +# end +# uctValue = childNode.statevalue + weightedterm + +# if uctValue > max_uct +# max_uct = uctValue +# selectedNode = childNode +# end +# end + +# return selectedNode +# end """ 
Expand selected node @@ -139,14 +161,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, while true nthSample += 1 if nthSample <= n - println("---> expand() sample $nthSample") thoughtDict = decisionMaker(a, node.state) - + println("---> expand() sample $nthSample") + pprintln(node.state[:thoughtHistory]) + pprintln(thoughtDict) newNodeKey, newstate, reward, isterminalstate = MCTStransition(a, node.state, thoughtDict) # add progressValueEstimator - stateevaluation, statevalue = progressValueEstimator(a, newstate) + stateevaluation, progressvalue = progressValueEstimator(a, newstate) if reward < 0 pprint(newstate[:thoughtHistory]) @@ -156,7 +179,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, # store new lesson for later use lessonDict = copy(JSON3.read("lesson.json")) latestLessonKey, latestLessonIndice = - GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson") + GeneralUtils.findHighestIndexKey(lessonDict, "lesson") nextIndice = latestLessonKey == :NA ? 
1 : latestLessonIndice + 1 newLessonKey = Symbol("lesson_$(nextIndice)") lessonDict[newLessonKey] = newstate @@ -167,7 +190,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, end if newNodeKey ∉ keys(node.children) - node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue, + node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, reward, isterminalstate, node, Dict{String, MCTSNode}()) end else @@ -236,10 +259,6 @@ julia> """ function backpropagate(node::MCTSNode, simTrajectoryReward::T; discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} - # [WORKING] store best trajectory - fullTrajectoryReward = 0 - isLeafNodeTerminalState = node.isterminal - terminalStateReward = node.reward while !isroot(node) # Update the statistics of the current node based on the result of the playout node.visits += 1 @@ -387,9 +406,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children - thisNodeProgressValue = childNode.statevalue + childNode.reward - if thisNodeProgressValue > highestProgressValue - highestProgressValue = thisNodeProgressValue + potential = childNode.progressvalue + childNode.reward + if potential > highestProgressValue + highestProgressValue = potential nodekey = childNode.nodekey end end @@ -468,10 +487,12 @@ function runMCTS( maxIterations::Integer, w::Float64) where {T1<:agent} - root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) + root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) for nth in 1:maxIterations node = root + node.visits += 1 + while !isleaf(node) node = UCTselect(node, w) end @@ -481,14 +502,14 @@ function runMCTS( backpropagate(leafNode, node.reward) else expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n) - leafNode = UCTselect(node, w) + leafNode = selectChildNode(node) 
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, reflector; maxDepth=maxDepth, n=n) backpropagate(leafNode, simTrajectoryReward) end end - best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) + best_child_state = argmax([child.statevalue / child.visits for child in values(root.children)]) error("---> runMCTS") return best_child_state end diff --git a/test/test_1.jl b/test/test_1.jl index b2430e9..887703e 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -66,7 +66,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "I like full-bodied Red wine with low tannin.", + :text=> "I like full-bodied red wine with low tannin.", :select=> nothing, :reward=> 0, :isterminal=> false, @@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "Are there any other options?", + :text=> "I like it dry.", :select=> nothing, :reward=> 0, :isterminal=> false, @@ -162,10 +162,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "What are you saying. I don't understand.", + :text=> "You didn't tell me wine name.", :select=> nothing, :reward=> -1, - :isterminal=> false, + :isterminal=> true, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg)