update
This commit is contained in:
@@ -416,7 +416,7 @@ Action 3: Search[Mexican Grand Prix winners]
|
|||||||
Observation 3: Could not find Mexican Grand Prix winners. Similar: ['Mexican Grand Prix', 'List of Formula One Grand Prix winners', '1990 Mexican Grand Prix', '2018 Mexican Grand Prix', '2019 Mexican Grand Prix'].
|
Observation 3: Could not find Mexican Grand Prix winners. Similar: ['Mexican Grand Prix', 'List of Formula One Grand Prix winners', '1990 Mexican Grand Prix', '2018 Mexican Grand Prix', '2019 Mexican Grand Prix'].
|
||||||
Thought 4: Given the difficulties in finding a direct list of Mexican Grand Prix winners, I should search for the Mexican Grand Prix to get a broader picture of the race's history. This might include winners.
|
Thought 4: Given the difficulties in finding a direct list of Mexican Grand Prix winners, I should search for the Mexican Grand Prix to get a broader picture of the race's history. This might include winners.
|
||||||
Action 4: Search[Mexican Grand Prix]
|
Action 4: Search[Mexican Grand Prix]
|
||||||
This trajectory is incorrect as my search should be related to Mexican Formula One race car drivers, not winners of the Mexican Grand Prix, a separate event. A better search would have been for the List of Formula One Grand Prix winners, as suggested.
|
This trajectory is incorrect as my search should be related to Mexican Formula One race car drivers, not winners of the Mexican Grand Prix, a separate event. A better search would have been for the List of Formula One Grand Prix winners, as suggested.
|
||||||
Thus the correctness score is 3
|
Thus the correctness score is 3
|
||||||
|
|
||||||
Question: Which magazine was started first Arthur's Magazine or First for Women?
|
Question: Which magazine was started first Arthur's Magazine or First for Women?
|
||||||
|
|||||||
114
src/interface.jl
114
src/interface.jl
@@ -126,7 +126,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
1) Get to know what occasion the user is buying wine for
|
1) Get to know what occasion the user is buying wine for
|
||||||
2) Get to know what food the user will have with wine
|
2) Get to know what food the user will have with wine
|
||||||
3) Get to know how much the user willing to spend
|
3) Get to know how much the user willing to spend
|
||||||
4) Get to know type of wine the user is looking for e.g. Red, White, Sparkling, Rose, Dessert, Fortified
|
4) Get to know type of wine the user is looking for e.g. red, white, sparkling, rose, dessert, fortified
|
||||||
5) Get to know what characteristics of wine the user is looking for
|
5) Get to know what characteristics of wine the user is looking for
|
||||||
e.g. tannin, sweetness, intensity, acidity
|
e.g. tannin, sweetness, intensity, acidity
|
||||||
6) Check your inventory for the best wine that match the user preference
|
6) Check your inventory for the best wine that match the user preference
|
||||||
@@ -197,10 +197,8 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
@show outgoingMsg
|
@show outgoingMsg
|
||||||
attempt = 0
|
|
||||||
while true
|
for attempt in 1:5
|
||||||
attempt += 1
|
|
||||||
if attempt <= 5
|
|
||||||
try
|
try
|
||||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
_responseJsonStr = response[:response][:text]
|
_responseJsonStr = response[:response][:text]
|
||||||
@@ -217,12 +215,15 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
thoughtDict = copy(JSON3.read(responseJsonStr))
|
thoughtDict = copy(JSON3.read(responseJsonStr))
|
||||||
|
|
||||||
# check if dict has all required value
|
# check if dict has all required value
|
||||||
dummya::AbstractString = thoughtDict[:thought]
|
thought::AbstractString = thoughtDict[:thought]
|
||||||
actionname::AbstractString = thoughtDict[:action][:name]
|
actionname::AbstractString = thoughtDict[:action][:name]
|
||||||
actioninput::AbstractString = thoughtDict[:action][:input]
|
actioninput::AbstractString = thoughtDict[:action][:input]
|
||||||
|
|
||||||
if actionname ∈ ["winestock", "chatbox", "recommendbox"]
|
if actionname ∈ ["winestock", "chatbox", "recommendbox"]
|
||||||
# LLM use available function
|
# LLM use available function
|
||||||
|
elseif thought == ""
|
||||||
|
error("DecisionMaker has no thought")
|
||||||
|
elseif length(actioninput) == 0
|
||||||
|
error("DecisionMaker has no actioninput")
|
||||||
else
|
else
|
||||||
error("DecisionMaker use wrong function")
|
error("DecisionMaker use wrong function")
|
||||||
end
|
end
|
||||||
@@ -234,14 +235,11 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
errorMsg = String(take!(io))
|
errorMsg = String(take!(io))
|
||||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||||
println("")
|
println("")
|
||||||
@warn "Error occurred: $errorMsg\n$st"
|
@warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||||
println("")
|
println("")
|
||||||
end
|
end
|
||||||
|
end
|
||||||
else
|
|
||||||
error("DecisionMaker failed to generate a thought")
|
error("DecisionMaker failed to generate a thought")
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -297,7 +295,22 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
}
|
}
|
||||||
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||||
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||||
"score": 10}
|
"score": 10
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
"question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?",
|
||||||
|
"thought_1": "Let me check our inventory first to see if I have it.",
|
||||||
|
"action_1": {"name": "inventory", "input": "pen with 4 color and a pencil."},
|
||||||
|
"observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}",
|
||||||
|
"thought_2": "Ok, I have what the user is asking. Let's tell the user.",
|
||||||
|
"action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"},
|
||||||
|
"observation_1": "This is not what I wanted."
|
||||||
|
}
|
||||||
|
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
||||||
|
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
||||||
|
"score": 2
|
||||||
|
}
|
||||||
|
|
||||||
Let's begin!:
|
Let's begin!:
|
||||||
$(JSON3.write(state[:thoughtHistory]))
|
$(JSON3.write(state[:thoughtHistory]))
|
||||||
@@ -334,10 +347,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
attempt = 0
|
for attempt in 1:5
|
||||||
while true
|
|
||||||
attempt += 1
|
|
||||||
if attempt <= 5
|
|
||||||
try
|
try
|
||||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
_responseJsonStr = response[:response][:text]
|
_responseJsonStr = response[:response][:text]
|
||||||
@@ -360,14 +370,11 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
errorMsg = String(take!(io))
|
errorMsg = String(take!(io))
|
||||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||||
println("")
|
println("")
|
||||||
@warn "Error occurred: $errorMsg\n$st"
|
@warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||||
println("")
|
println("")
|
||||||
end
|
end
|
||||||
|
end
|
||||||
else
|
|
||||||
error("progressValueEstimator failed to generate an evaluation")
|
error("progressValueEstimator failed to generate an evaluation")
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -392,55 +399,6 @@ julia>
|
|||||||
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||||
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
|
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
|
||||||
|
|
||||||
# _prompt =
|
|
||||||
# """
|
|
||||||
# You are an advanced reasoning agent that can improve based on self refection.
|
|
||||||
# You will be given the information related to the previous help you've done for a user.
|
|
||||||
# You were unsuccessful in helping the user either because you guessed the wrong answer with Finish[answer], or you used up your set number of reasoning steps.
|
|
||||||
# In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure.
|
|
||||||
# Use complete sentences.
|
|
||||||
|
|
||||||
# You should only respond in JSON format as describe below:
|
|
||||||
# {"reflection": "your relection"}
|
|
||||||
|
|
||||||
# Here are some examples:
|
|
||||||
# Previous Trial:
|
|
||||||
# {
|
|
||||||
# "question": "Hello, I would like a get a bottle of wine",
|
|
||||||
# "thought_1": "A customer wants to buy a bottle of wine. Before making a recommendation, I need to know more about their preferences.",
|
|
||||||
# "action_1": {"name": "chatbox", "input": "What is the occasion for which you're buying this wine?"},
|
|
||||||
# "observation_1": "We are holding a wedding party",
|
|
||||||
|
|
||||||
# "thought_2": "A wedding party, that's a great occasion! The customer might be looking for a celebratory drink. Let me ask some more questions to narrow down the options.",
|
|
||||||
# "action_2": {"name": "chatbox", "input": "What type of food will you be serving at the wedding?"},
|
|
||||||
# "observation_2": "It will be Thai dishes.",
|
|
||||||
|
|
||||||
# "thought_3": "With Thai food, I should recommend a wine that complements its spicy and savory flavors. And since it's a celebratory occasion, the customer might prefer a full-bodied wine.",
|
|
||||||
# "action_3": {"name": "chatbox", "input": "What is your budget for this bottle of wine?"},
|
|
||||||
# "observation_3": "I would spend up to 50 bucks.",
|
|
||||||
|
|
||||||
# "thought_4": "Now that I have some more information, it's time to narrow down the options.",
|
|
||||||
# "action_4": {"name": "winestock", "input": "red wine with full body, pairs well with spicy food, budget \$50"},
|
|
||||||
# "observation_4": "I found the following wines in our stock: \n{\n 1: El Enemigo Cabernet Franc 2019\n2: Tantara Chardonnay 2017\n\n}\n",
|
|
||||||
|
|
||||||
# "thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.",
|
|
||||||
# "action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"},
|
|
||||||
# "observation_5": "I like full-bodied Red wine with low tannin.",
|
|
||||||
|
|
||||||
# "thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.",
|
|
||||||
# "action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"},
|
|
||||||
# "observation_6": "I don't like the one you recommend. I want dry wine."
|
|
||||||
# }
|
|
||||||
|
|
||||||
# {
|
|
||||||
# "reflection": "I asked the user about the occasion, food type, and budget, and then searched for wine in the inventory right away. However, I should have asked the user for the specific wine type and their preferences in order to gather more information before making a recommendation."
|
|
||||||
# }
|
|
||||||
|
|
||||||
# Previous trial:
|
|
||||||
# $(JSON3.write(state[:thoughtHistory]))
|
|
||||||
# {"reflection"
|
|
||||||
# """
|
|
||||||
|
|
||||||
_prompt =
|
_prompt =
|
||||||
"""
|
"""
|
||||||
You are a helpful sommelier working for a wine store.
|
You are a helpful sommelier working for a wine store.
|
||||||
@@ -475,7 +433,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
|||||||
|
|
||||||
"thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.",
|
"thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.",
|
||||||
"action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"},
|
"action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"},
|
||||||
"observation_5": "I like full-bodied Red wine with low tannin.",
|
"observation_5": "I like full-bodied red wine with low tannin.",
|
||||||
|
|
||||||
"thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.",
|
"thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.",
|
||||||
"action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"},
|
"action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"},
|
||||||
@@ -523,10 +481,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
attempt = 0
|
for attempt in 1:5
|
||||||
while true
|
|
||||||
attempt += 1
|
|
||||||
if attempt <= 5
|
|
||||||
try
|
try
|
||||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
_responseJsonStr = response[:response][:text]
|
_responseJsonStr = response[:response][:text]
|
||||||
@@ -548,14 +503,11 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
|||||||
errorMsg = String(take!(io))
|
errorMsg = String(take!(io))
|
||||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||||
println("")
|
println("")
|
||||||
@warn "Error occurred: $errorMsg\n$st"
|
@warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||||
println("")
|
println("")
|
||||||
end
|
end
|
||||||
|
end
|
||||||
else
|
|
||||||
error("reflector failed to generate a thought")
|
error("reflector failed to generate a thought")
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ function virtualWineCustomerReccommendbox(a::T1, input
|
|||||||
:text=> prompt,
|
:text=> prompt,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@show outgoingMsg
|
|
||||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||||
response = result[:response]
|
response = result[:response]
|
||||||
|
|
||||||
@@ -162,9 +162,7 @@ function virtualWineCustomerChatbox(a::T1, input::T2
|
|||||||
)
|
)
|
||||||
|
|
||||||
attempt = 0
|
attempt = 0
|
||||||
while true
|
for attempt in 1:5
|
||||||
attempt += 1
|
|
||||||
if attempt <= 5
|
|
||||||
try
|
try
|
||||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||||
response = result[:response]
|
response = result[:response]
|
||||||
@@ -179,11 +177,8 @@ function virtualWineCustomerChatbox(a::T1, input::T2
|
|||||||
@warn "Error occurred: $errorMsg\n$st"
|
@warn "Error occurred: $errorMsg\n$st"
|
||||||
println("")
|
println("")
|
||||||
end
|
end
|
||||||
|
end
|
||||||
else
|
|
||||||
error("virtualWineCustomerChatbox failed to get a response")
|
error("virtualWineCustomerChatbox failed to get a response")
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -252,18 +247,16 @@ julia>
|
|||||||
function jsoncorrection(a::T1, input::T2,
|
function jsoncorrection(a::T1, input::T2,
|
||||||
correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
|
correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
|
||||||
|
|
||||||
attemptround = 0
|
|
||||||
incorrectjson = deepcopy(input)
|
incorrectjson = deepcopy(input)
|
||||||
correctjson = nothing
|
correctjson = nothing
|
||||||
while true
|
|
||||||
attemptround += 1
|
for attempt in 1:5
|
||||||
if attemptround <= 5
|
|
||||||
try
|
try
|
||||||
d = copy(JSON3.read(incorrectjson))
|
d = copy(JSON3.read(incorrectjson))
|
||||||
correctjson = incorrectjson
|
correctjson = incorrectjson
|
||||||
break
|
return correctjson
|
||||||
catch e
|
catch e
|
||||||
@warn "Attempting correct JSON string. $attemptround"
|
@warn "Attempting correct JSON string. Attempt $attempt"
|
||||||
e = """$e"""
|
e = """$e"""
|
||||||
if occursin("EOF", e)
|
if occursin("EOF", e)
|
||||||
e = split(e, "EOF")[1] * "EOF"
|
e = split(e, "EOF")[1] * "EOF"
|
||||||
@@ -316,13 +309,7 @@ function jsoncorrection(a::T1, input::T2,
|
|||||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||||
incorrectjson = result[:response][:text]
|
incorrectjson = result[:response][:text]
|
||||||
end
|
end
|
||||||
else
|
|
||||||
error("Can't fix JSON string")
|
|
||||||
break
|
|
||||||
end
|
end
|
||||||
end
|
|
||||||
|
|
||||||
return correctjson
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
69
src/mcts.jl
69
src/mcts.jl
@@ -51,6 +51,7 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
|||||||
nodekey::T2
|
nodekey::T2
|
||||||
state::T1
|
state::T1
|
||||||
visits::Integer
|
visits::Integer
|
||||||
|
progressvalue::Number
|
||||||
statevalue::Number
|
statevalue::Number
|
||||||
reward::Number
|
reward::Number
|
||||||
isterminal::Bool
|
isterminal::Bool
|
||||||
@@ -78,26 +79,47 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||||
max_uct = -Inf
|
maxUCT = -Inf
|
||||||
selectedNode = nothing
|
selectedNode = nothing
|
||||||
|
|
||||||
for (childState, childNode) in node.children
|
for (childState, childNode) in node.children
|
||||||
weightedterm =
|
UCTvalue =
|
||||||
if node.visits == 0 || childNode.visits == 0
|
if childNode.visits != 0
|
||||||
0
|
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
|
||||||
else
|
childNode.statevalue + weightedterm
|
||||||
w * sqrt(log(node.visits) / childNode.visits)
|
else # node.visits == 0 makes sqrt() in explore term error
|
||||||
|
childNode.progressvalue # exploit term
|
||||||
end
|
end
|
||||||
uctValue = childNode.statevalue + weightedterm
|
|
||||||
|
|
||||||
if uctValue > max_uct
|
if UCTvalue > maxUCT
|
||||||
max_uct = uctValue
|
maxUCT = UCTvalue
|
||||||
selectedNode = childNode
|
selectedNode = childNode
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return selectedNode
|
return selectedNode
|
||||||
end
|
end
|
||||||
|
# function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||||
|
# max_uct = -Inf
|
||||||
|
# selectedNode = nothing
|
||||||
|
|
||||||
|
# for (childState, childNode) in node.children
|
||||||
|
# weightedterm =
|
||||||
|
# if node.visits == 0 || childNode.visits == 0 # node.visits == 0 makes sqrt() error
|
||||||
|
# 0
|
||||||
|
# else
|
||||||
|
# w * sqrt(log(node.visits) / childNode.visits)
|
||||||
|
# end
|
||||||
|
# uctValue = childNode.statevalue + weightedterm
|
||||||
|
|
||||||
|
# if uctValue > max_uct
|
||||||
|
# max_uct = uctValue
|
||||||
|
# selectedNode = childNode
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
# return selectedNode
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
""" Expand selected node
|
""" Expand selected node
|
||||||
@@ -139,14 +161,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
while true
|
while true
|
||||||
nthSample += 1
|
nthSample += 1
|
||||||
if nthSample <= n
|
if nthSample <= n
|
||||||
println("---> expand() sample $nthSample")
|
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
|
println("---> expand() sample $nthSample")
|
||||||
|
pprintln(node.state[:thoughtHistory])
|
||||||
|
pprintln(thoughtDict)
|
||||||
newNodeKey, newstate, reward, isterminalstate =
|
newNodeKey, newstate, reward, isterminalstate =
|
||||||
MCTStransition(a, node.state, thoughtDict)
|
MCTStransition(a, node.state, thoughtDict)
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
stateevaluation, progressvalue = progressValueEstimator(a, newstate)
|
||||||
|
|
||||||
if reward < 0
|
if reward < 0
|
||||||
pprint(newstate[:thoughtHistory])
|
pprint(newstate[:thoughtHistory])
|
||||||
@@ -156,7 +179,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
# store new lesson for later use
|
# store new lesson for later use
|
||||||
lessonDict = copy(JSON3.read("lesson.json"))
|
lessonDict = copy(JSON3.read("lesson.json"))
|
||||||
latestLessonKey, latestLessonIndice =
|
latestLessonKey, latestLessonIndice =
|
||||||
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
|
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
|
||||||
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||||
newLessonKey = Symbol("lesson_$(nextIndice)")
|
newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||||
lessonDict[newLessonKey] = newstate
|
lessonDict[newLessonKey] = newstate
|
||||||
@@ -167,7 +190,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
end
|
end
|
||||||
|
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0,
|
||||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
@@ -236,10 +259,6 @@ julia>
|
|||||||
"""
|
"""
|
||||||
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||||
# [WORKING] store best trajectory
|
|
||||||
fullTrajectoryReward = 0
|
|
||||||
isLeafNodeTerminalState = node.isterminal
|
|
||||||
terminalStateReward = node.reward
|
|
||||||
while !isroot(node)
|
while !isroot(node)
|
||||||
# Update the statistics of the current node based on the result of the playout
|
# Update the statistics of the current node based on the result of the playout
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
@@ -387,9 +406,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
|||||||
|
|
||||||
# loop thought node children dictionary to find the highest progress value
|
# loop thought node children dictionary to find the highest progress value
|
||||||
for (k, childNode) in node.children
|
for (k, childNode) in node.children
|
||||||
thisNodeProgressValue = childNode.statevalue + childNode.reward
|
potential = childNode.progressvalue + childNode.reward
|
||||||
if thisNodeProgressValue > highestProgressValue
|
if potential > highestProgressValue
|
||||||
highestProgressValue = thisNodeProgressValue
|
highestProgressValue = potential
|
||||||
nodekey = childNode.nodekey
|
nodekey = childNode.nodekey
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -468,10 +487,12 @@ function runMCTS(
|
|||||||
maxIterations::Integer,
|
maxIterations::Integer,
|
||||||
w::Float64) where {T1<:agent}
|
w::Float64) where {T1<:agent}
|
||||||
|
|
||||||
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
for nth in 1:maxIterations
|
for nth in 1:maxIterations
|
||||||
node = root
|
node = root
|
||||||
|
node.visits += 1
|
||||||
|
|
||||||
while !isleaf(node)
|
while !isleaf(node)
|
||||||
node = UCTselect(node, w)
|
node = UCTselect(node, w)
|
||||||
end
|
end
|
||||||
@@ -481,14 +502,14 @@ function runMCTS(
|
|||||||
backpropagate(leafNode, node.reward)
|
backpropagate(leafNode, node.reward)
|
||||||
else
|
else
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
||||||
leafNode = UCTselect(node, w)
|
leafNode = selectChildNode(node)
|
||||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||||
reflector; maxDepth=maxDepth, n=n)
|
reflector; maxDepth=maxDepth, n=n)
|
||||||
backpropagate(leafNode, simTrajectoryReward)
|
backpropagate(leafNode, simTrajectoryReward)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
best_child_state = argmax([child.statevalue / child.visits for child in values(root.children)])
|
||||||
error("---> runMCTS")
|
error("---> runMCTS")
|
||||||
return best_child_state
|
return best_child_state
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
:text=> "I like full-bodied Red wine with low tannin.",
|
:text=> "I like full-bodied red wine with low tannin.",
|
||||||
:select=> nothing,
|
:select=> nothing,
|
||||||
:reward=> 0,
|
:reward=> 0,
|
||||||
:isterminal=> false,
|
:isterminal=> false,
|
||||||
@@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
:text=> "Are there any other options?",
|
:text=> "I like it dry.",
|
||||||
:select=> nothing,
|
:select=> nothing,
|
||||||
:reward=> 0,
|
:reward=> 0,
|
||||||
:isterminal=> false,
|
:isterminal=> false,
|
||||||
@@ -162,10 +162,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
|||||||
outgoingMsg = Dict(
|
outgoingMsg = Dict(
|
||||||
:msgMeta=> msgMeta,
|
:msgMeta=> msgMeta,
|
||||||
:payload=> Dict(
|
:payload=> Dict(
|
||||||
:text=> "What are you saying. I don't understand.",
|
:text=> "You didn't tell me wine name.",
|
||||||
:select=> nothing,
|
:select=> nothing,
|
||||||
:reward=> -1,
|
:reward=> -1,
|
||||||
:isterminal=> false,
|
:isterminal=> true,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||||
|
|||||||
Reference in New Issue
Block a user