This commit is contained in:
narawat lamaiin
2024-05-13 17:37:44 +07:00
parent 8431258f1c
commit 62c6ce90ed
4 changed files with 432 additions and 86 deletions

View File

@@ -295,7 +295,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
} }
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question. {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
It is also better to have simple searches corresponding to a single entity, making this the best action.", It is also better to have simple searches corresponding to a single entity, making this the best action.",
"score": 7 "score": 10
} }
{ {
@@ -309,7 +309,7 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
} }
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it, {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.", not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
"score": 3 "score": 0
} }
Let's begin!: Let's begin!:
@@ -378,6 +378,143 @@ function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T
end end
# """
# # Arguments
# - `a::T1`
# one of Yiem's agent
# - `state::T2`
# a game state
# # Return
# - `evaluation::Tuple{String, Integer}`
# evaluation and score
# # Example
# ```jldoctest
# julia>
# ```
# # TODO
# - [] update docs
# - [] implement the function
# # Signature
# """
# function comparer(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
# _prompt =
# """
# Analyze the trajectories of a solution to a question answering task. The trajectories are
# labeled by environmental observations about the situation, thoughts that can reason about
# the current situation and actions that can be three types:
# 1) winestock[query], which you can use to find wine in your inventory.
# 2) chatbox[text], which you can use to interact with the user.
# 3) recommendbox[answer], which returns your wine recommendation to the user.
# Given a question and a trajectory, evaluate its correctness and provide your reasoning and
# analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
# can be correct if the thoughts and actions so far are correct, even if the answer is not found
# yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
# where s is an integer from 0 to 10.
# You should only respond in JSON format as describe below:
# {"evaluation": "your evaluation", "score": "your evaluation score"}
# Here are some examples:
# {
# "question": "I'm looking for a sedan with an automatic driving feature.",
# "thought_1": "I have many types of sedans in my inventory, each with diverse features.",
# "thought_2": "But there is only 1 model that has the feature customer wanted.",
# "thought_3": "I should check our inventory first to see if we have it.",
# "action_1": {"name": "inventory", "input": "Yiem model A"},
# "observation_1": "Yiem model A is in stock."
# }
# {"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
# It is also better to have simple searches corresponding to a single entity, making this the best action.",
# "score": 10
# }
# {
# "question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?",
# "thought_1": "Let me check our inventory first to see if I have it.",
# "action_1": {"name": "inventory", "input": "pen with 4 color and a pencil."},
# "observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}",
# "thought_2": "Ok, I have what the user is asking. Let's tell the user.",
# "action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"},
# "observation_1": "This is not what I wanted."
# }
# {"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
# not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
# "score": 0
# }
# Let's begin!:
# $(JSON3.write(state[:thoughtHistory]))
# {"evaluation"
# """
# # apply LLM specific instruct format
# externalService = a.config[:externalservice][:text2textinstruct]
# llminfo = externalService[:llminfo]
# prompt =
# if llminfo[:name] == "llama3instruct"
# formatLLMtext_llama3instruct("system", _prompt)
# else
# error("llm model name is not defied yet $(@__LINE__)")
# end
# msgMeta = GeneralUtils.generate_msgMeta(
# a.config[:externalservice][:text2textinstruct][:mqtttopic],
# senderName= "evaluator",
# senderId= a.id,
# receiverName= "text2textinstruct",
# mqttBroker= a.config[:mqttServerInfo][:broker],
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
# )
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :text=> prompt,
# :kwargs=> Dict(
# :max_tokens=> 512,
# :stop=> ["<|eot_id|>"],
# )
# )
# )
# for attempt in 1:5
# try
# response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
# _responseJsonStr = response[:response][:text]
# expectedJsonExample =
# """
# Here is an expected JSON format:
# {"evaluation": "...", "score": "..."}
# """
# responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
# evaluationDict = copy(JSON3.read(responseJsonStr))
# # check if dict has all required value
# dummya::AbstractString = evaluationDict[:evaluation]
# dummyb::Integer = evaluationDict[:score]
# return (evaluationDict[:evaluation], evaluationDict[:score])
# catch e
# io = IOBuffer()
# showerror(io, e)
# errorMsg = String(take!(io))
# st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
# println("")
# @warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
# println("")
# end
# end
# error("evaluator failed to generate an evaluation")
# end
""" """
# Arguments # Arguments
@@ -600,8 +737,9 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
# TODO # TODO
- [] update docstring - [] update docstring
- [WORKING] MCTS() for planning - [x] MCTS() for planning
- [] add recap to initialState for earlier completed question - [] add recap to initialState for earlier completed question
- [WORKING] conversation loop
# Signature # Signature
""" """
@@ -617,36 +755,36 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
# add usermsg to a.chathistory # add usermsg to a.chathistory
addNewMessage(a, "user", userinput[:text]) addNewMessage(a, "user", userinput[:text])
#[] if the last used tool is a chatbox, put usermsg -> observation and continue actor loop as planned currentstate =
if !isempty(a.plan[:currenttrajectory]) && if isempty(a.plan[:currenttrajectory])
a.plan[:currenttrajectory][end][:action] == "chatbox" # set up initial state
Dict{Symbol, Any}(
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:userselect=> nothing,
:reward=> 0,
:isterminal=> false,
:evaluation=> nothing,
:lesson=> nothing,
:thoughtDict=> nothing,
:totalTrajectoryReward=> nothing,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# :recap=>,
:question=> userinput[:text],
)
)
else else
initialState = Dict{Symbol, Any}( a.plan[:currenttrajectory]
end
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:select=> nothing, # transition
:reward=> 0, newstate = transition(a, bestNextState)
:isterminal=> false, a.plan[:currenttrajectory] = newstate
:evaluation=> nothing,
:lesson=> nothing,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# :recap=>,
:question=> userinput[:text],
)
)
bestplan = runMCTS(a, initialState, decisionMaker, evaluator, reflector,
2, 3, 4, 1.0)
error("---> bestplan")
# actor loop(bestplan)
end
end end
end end

View File

@@ -1,7 +1,7 @@
module llmfunction module llmfunction
export virtualWineCustomerChatbox, jsoncorrection, winestock, export virtualWineUserChatbox, jsoncorrection, winestock,
virtualWineCustomerReccommendbox virtualWineUserRecommendbox, userChatbox, userRecommendbox
using HTTP, JSON3, URIs, Random using HTTP, JSON3, URIs, Random
using GeneralUtils using GeneralUtils
@@ -26,8 +26,46 @@ julia>
# Signature # Signature
""" """
function chatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} function userChatbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
error("--> chatbox") error("--> userChatbox")
# put in model format
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
llminfo = virtualWineCustomer[:llminfo]
formattedinput =
if llminfo[:name] == "llama3instruct"
formatLLMtext_llama3instruct("assistant", input)
else
error("llm model name is not defied yet $(@__LINE__)")
end
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
# return response
end
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function
# Signature
"""
function userRecommendbox(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
error("--> userRecommendbox")
# put in model format # put in model format
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
@@ -69,7 +107,7 @@ julia>
# Signature # Signature
""" """
function virtualWineCustomerReccommendbox(a::T1, input function virtualWineUserRecommendbox(a::T1, input
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent} )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent}
# put in model format # put in model format
@@ -85,7 +123,7 @@ function virtualWineCustomerReccommendbox(a::T1, input
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
virtualWineCustomer[:mqtttopic], virtualWineCustomer[:mqtttopic],
senderName= "virtualWineCustomerReccommendbox", senderName= "virtualWineUserRecommendbox",
senderId= a.id, senderId= a.id,
receiverName= "virtualWineCustomer", receiverName= "virtualWineCustomer",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= a.config[:mqttServerInfo][:broker],
@@ -125,12 +163,11 @@ julia>
``` ```
# TODO # TODO
- [] update docs - [] update docs
- [] add to remove <<< user option select >>> and <<| reward |>>
# Signature # Signature
""" """
function virtualWineCustomerChatbox(a::T1, input::T2 function virtualWineUserChatbox(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# put in model format # put in model format
@@ -146,7 +183,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
msgMeta = GeneralUtils.generate_msgMeta( msgMeta = GeneralUtils.generate_msgMeta(
virtualWineCustomer[:mqtttopic], virtualWineCustomer[:mqtttopic],
senderName= "virtualWineCustomerChatbox", senderName= "virtualWineUserChatbox",
senderId= a.id, senderId= a.id,
receiverName= "virtualWineCustomer", receiverName= "virtualWineCustomer",
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= a.config[:mqttServerInfo][:broker],
@@ -178,7 +215,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
println("") println("")
end end
end end
error("virtualWineCustomerChatbox failed to get a response") error("virtualWineUserChatbox failed to get a response")
end end

View File

@@ -5,7 +5,8 @@
module mcts module mcts
export MCTSNode, runMCTS, isleaf export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
userChatbox
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils using GeneralUtils
@@ -51,9 +52,9 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2 nodekey::T2
state::T1 state::T1
visits::Integer visits::Integer
progressvalue::Number progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number statevalue::Number # store discounted cumulative reward (gathered from its child nodes)
reward::Number reward::Number # this node's own reward
isterminal::Bool isterminal::Bool
parent::Union{MCTSNode, Nothing} parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode} children::Dict{String, MCTSNode}
@@ -132,23 +133,24 @@ julia>
# Signature # Signature
""" """
function expand(a::T1, node::MCTSNode, decisionMaker::Function, function expand(a::T1, node::MCTSNode, decisionMaker::Function,
evaluator::Function, reflector::Function; n::Integer=3) where {T1<:agent} evaluator::Function, reflector::Function; totalsample::Integer=3
) where {T1<:agent}
nthSample = 0 nthSample = 0
while true while true
nthSample += 1 nthSample += 1
if nthSample <= n if nthSample <= totalsample
thoughtDict = decisionMaker(a, node.state) thoughtDict = decisionMaker(a, node.state)
println("---> expand() sample $nthSample") println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory]) pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict) pprintln(thoughtDict)
newNodeKey, newstate, reward, isterminalstate = node.state[:thoughtDict] = thoughtDict
MCTStransition(a, node.state, thoughtDict) newNodeKey, newstate = MCTStransition(a, node.state)
# add evaluator # add evaluator
stateevaluation, progressvalue = evaluator(a, newstate) stateevaluation, progressvalue = evaluator(a, newstate)
if reward < 0 if newstate[:reward] < 0
pprint(newstate[:thoughtHistory]) pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate) newstate[:lesson] = reflector(a, newstate)
@@ -167,8 +169,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
end end
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, node.children[newNodeKey] =
reward, isterminalstate, node, Dict{String, MCTSNode}()) MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}())
end end
else else
break break
@@ -196,24 +199,30 @@ end
julia> julia>
``` ```
# TODO
- [] update docs
# Signature # Signature
""" """
function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent} reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:agent}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing
for depth in 1:maxDepth for depth in 1:maxDepth
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
if node.isterminal if node.isterminal
terminalstate = node.state
break break
else else
expand(a, node, decisionMaker, evaluator, reflector; n=n) expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end
return simTrajectoryReward return (simTrajectoryReward, terminalstate)
end end
""" Backpropagate reward along the simulation chain """ Backpropagate reward along the simulation chain
@@ -285,20 +294,21 @@ julia> thoughtDict = Dict(
# Signature # Signature
""" """
function MCTStransition(a::T1, state::T2, thoughtDict::T3 function MCTStransition(a::T1, state::T2
)::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
thoughtDict = state[:thoughtDict]
actionname = thoughtDict[:action][:name] actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input] actioninput = thoughtDict[:action][:input]
# map action and input() to llm function # map action and input() to llm function
response, select, reward, isterminal = response, select, reward, isterminal =
if actionname == "chatbox" if actionname == "chatbox"
virtualWineCustomerChatbox(a, actioninput) # virtual customer virtualWineUserChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock" elseif actionname == "winestock"
winestock(a, actioninput) winestock(a, actioninput)
elseif actionname == "recommendbox" elseif actionname == "recommendbox"
virtualWineCustomerReccommendbox(a, actioninput) virtualWineUserRecommendbox(a, actioninput)
else else
error("undefined LLM function. Requesting $actionname") error("undefined LLM function. Requesting $actionname")
end end
@@ -321,7 +331,85 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate, reward, isterminal) return (newNodeKey, newstate)
end
""" Apply the action stored in `state[:thoughtDict]` and return the resulting state.

# Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state. It must contain `:thoughtDict` (with `:thought` and an `:action`
holding `:name` and `:input`) and `:thoughtHistory`.

# Return
- `newstate::Dict{Symbol, <:Any}`
a deep copy of `state` whose `:thoughtHistory` has the new `thought_n`, `action_n` and
`observation_n` entries appended, and whose `:reward`, `:select` and `:isterminal` are
set from the result of the called tool.

# Example
```jldoctest
julia>
```

# TODO
- [x] add other actions
- [] add embedding of newstate and store in newstate[:embedding]
- [] update docs

# Signature
"""
function transition(a::T1, state::T2
    )::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}

    decision = state[:thoughtDict]
    requestedaction = decision[:action]
    actionname = requestedaction[:name]
    actioninput = requestedaction[:input]

    # run the tool that matches the requested action; every tool is expected to return
    # a 4-element result (response, select, reward, isterminal)
    toolresult =
        if actionname == "chatbox"
            userChatbox(a, actioninput)
        elseif actionname == "winestock"
            winestock(a, actioninput)
        elseif actionname == "recommendbox"
            userRecommendbox(a, actioninput)
        else
            error("undefined LLM function. Requesting $actionname")
        end
    response, select, reward, isterminal = toolresult

    # work out the index of the next thought/action/observation triple;
    # a key of :NA means the history has no thought yet
    latestThoughtKey, latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought")
    nextIndice =
        if latestThoughtKey == :NA
            1
        else
            latestThoughtIndice + 1
        end

    # never mutate the incoming state, extend a deep copy instead
    newstate = deepcopy(state)
    history = newstate[:thoughtHistory]
    history[Symbol("thought_$nextIndice")] = decision[:thought]
    history[Symbol("action_$nextIndice")] = decision[:action]
    history[Symbol("observation_$(nextIndice)")] = response

    newstate[:reward] = reward
    # NOTE(review): the initial state built in conversation() uses `:userselect`,
    # this function writes `:select` -- confirm which key the consumers read.
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    return newstate
end
@@ -396,6 +484,90 @@ function selectChildNode(node::MCTSNode)::MCTSNode
end end
""" Select the most promising direct child of `node`.

Children are ranked by their average state value (`statevalue / visits`). When every
child still has `statevalue == 0`, the ranking falls back to `progressvalue + reward`.
The best-ranked child is returned even if its score is zero or negative. When several
children share the best score, the first one encountered while iterating
`node.children` is returned.

# Arguments
- `node::MCTSNode`
node of a search tree. It must have at least one child.

# Return
- `childNode::MCTSNode`
the highest value child node

# Throws
- `ErrorException` if `node` has no children

# Example
```jldoctest
julia>
```

# TODO
- [] update docs
- [TESTING] implement the function

# Signature
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
    if isempty(node.children)
        error("selectBestNextState: node \"$(node.nodekey)\" has no children to select from")
    end

    candidates = collect(values(node.children))

    # if all child nodes have statevalue == 0, use progressvalue + reward to select the
    # best node. Checking each child (instead of the sum) keeps values that cancel each
    # other out from being mistaken for "no state value yet".
    usestatevalue = any(child -> child.statevalue != 0, candidates)

    bestnode = nothing
    bestscore = -Inf
    for childnode in candidates
        score =
            if usestatevalue
                # an unvisited child has no average yet; rank it last instead of
                # dividing by zero
                childnode.visits > 0 ? childnode.statevalue / childnode.visits : -Inf
            else
                childnode.progressvalue + childnode.reward
            end

        # the first candidate is always accepted so a child is returned even when
        # every score is <= 0; later candidates must be strictly better
        if bestnode === nothing || score > bestscore
            bestnode = childnode
            bestscore = score
        end
    end

    return bestnode
end
""" Follow the best child from `node` downwards until a leaf node is reached.

# Arguments
- `node::MCTSNode`
node of a search tree where the descent starts

# Return
- `leafNode::MCTSNode`
the leaf node reached by repeatedly applying `selectBestNextState`. `node` itself is
returned when it is already a leaf.

# Example
```jldoctest
julia>
```

# TODO
- [] update docs
- [TESTING] implement the function

# Signature
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
    current = node
    while true
        # stop as soon as there is nothing below the current node
        if isleaf(current)
            return current
        end
        current = selectBestNextState(current)
    end
end
""" Determine whether a given node is a root node """ Determine whether a given node is a root node
# Arguments # Arguments
@@ -451,7 +623,7 @@ julia>
# TODO # TODO
[] update docstring [] update docstring
[PENDING] return best plan [x] return best action
# Signature # Signature
""" """
@@ -460,46 +632,48 @@ function runMCTS(
initialState, initialState,
decisionMaker::Function, decisionMaker::Function,
evaluator::Function, evaluator::Function,
reflector::Function, reflector::Function;
n::Integer, totalsample::Integer=3,
maxDepth::Integer, maxDepth::Integer=3,
maxIterations::Integer, maxiterations::Integer=10,
w::Float64 explorationweight::Number=1.0,
) where {T1<:agent} ) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxIterations for nth in 1:maxiterations
node = root node = root
node.visits += 1 node.visits += 1
while !isleaf(node) while !isleaf(node)
node = UCTselect(node, w) node = UCTselect(node, explorationweight)
end end
if node.isterminal if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state, # MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation # do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward) backpropagate(leafNode, node.reward)
else else
expand(a, node, decisionMaker, evaluator, reflector; n=n) expand(a, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
leafNode = selectChildNode(node) leafNode = selectChildNode(node)
simTrajectoryReward = simulate(a, leafNode, decisionMaker, evaluator, simTrajectoryReward, terminalstate = simulate(a, leafNode, decisionMaker, evaluator,
reflector; maxDepth=maxDepth, n=n) reflector; maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward) backpropagate(leafNode, simTrajectoryReward)
end end
end end
avgStateValue = 0 bestNextState = selectBestNextState(root)
selectedChildKey = nothing besttrajectory = selectBestTrajectory(root)
for (k, v) in root.children
k_avgStateValue = v.statevalue / v.visits
if k_avgStateValue > avgStateValue
avgStateValue = k_avgStateValue
selectedChildKey = k
end
end
return root.children[selectedChildKey] return (bestNextState.state, besttrajectory.state)
end end
@@ -525,9 +699,6 @@ end

View File

@@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "I like it dry.", :text=> "I like dry wine with fruity flavors.",
:select=> nothing, :select=> nothing,
:reward=> 0, :reward=> 0,
:isterminal=> false, :isterminal=> false,