update
This commit is contained in:
@@ -416,7 +416,7 @@ Action 3: Search[Mexican Grand Prix winners]
|
||||
Observation 3: Could not find Mexican Grand Prix winners. Similar: ['Mexican Grand Prix', 'List of Formula One Grand Prix winners', '1990 Mexican Grand Prix', '2018 Mexican Grand Prix', '2019 Mexican Grand Prix'].
|
||||
Thought 4: Given the difficulties in finding a direct list of Mexican Grand Prix winners, I should search for the Mexican Grand Prix to get a broader picture of the race's history. This might include winners.
|
||||
Action 4: Search[Mexican Grand Prix]
|
||||
This trajectory is incorrect as my search should be related to Mexican Formula One race car drivers, not winners of the Mexican Grand Prix, a separate event. A better search would have been for the List of Formula One Grand Prix winners, as suggested.
|
||||
This trajectory is incorrect as my search should be related to Mexican Formula One race car drivers, not winners of the Mexican Grand Prix, a separate event. A better search would have been for the List of Formula One Grand Prix winners, as suggested.
|
||||
Thus the correctness score is 3
|
||||
|
||||
Question: Which magazine was started first Arthur's Magazine or First for Women?
|
||||
|
||||
252
src/interface.jl
252
src/interface.jl
@@ -126,7 +126,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
1) Get to know what occasion the user is buying wine for
|
||||
2) Get to know what food the user will have with wine
|
||||
3) Get to know how much the user willing to spend
|
||||
4) Get to know type of wine the user is looking for e.g. Red, White, Sparkling, Rose, Dessert, Fortified
|
||||
4) Get to know type of wine the user is looking for e.g. red, white, sparkling, rose, dessert, fortified
|
||||
5) Get to know what characteristics of wine the user is looking for
|
||||
e.g. tannin, sweetness, intensity, acidity
|
||||
6) Check your inventory for the best wine that match the user preference
|
||||
@@ -197,51 +197,49 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
)
|
||||
)
|
||||
@show outgoingMsg
|
||||
attempt = 0
|
||||
while true
|
||||
attempt += 1
|
||||
if attempt <= 5
|
||||
try
|
||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
_responseJsonStr = response[:response][:text]
|
||||
expectedJsonExample =
|
||||
"""
|
||||
Here is an expected JSON format:
|
||||
{
|
||||
"thought": "...",
|
||||
"action": {"name": "...", "input": "..."},
|
||||
"observation": "..."
|
||||
}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
thoughtDict = copy(JSON3.read(responseJsonStr))
|
||||
|
||||
# check that the dict has all required values
|
||||
dummya::AbstractString = thoughtDict[:thought]
|
||||
actionname::AbstractString = thoughtDict[:action][:name]
|
||||
actioninput::AbstractString = thoughtDict[:action][:input]
|
||||
for attempt in 1:5
|
||||
try
|
||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
_responseJsonStr = response[:response][:text]
|
||||
expectedJsonExample =
|
||||
"""
|
||||
Here is an expected JSON format:
|
||||
{
|
||||
"thought": "...",
|
||||
"action": {"name": "...", "input": "..."},
|
||||
"observation": "..."
|
||||
}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
thoughtDict = copy(JSON3.read(responseJsonStr))
|
||||
|
||||
if actionname ∈ ["winestock", "chatbox", "recommendbox"]
|
||||
# LLM use available function
|
||||
else
|
||||
error("DecisionMaker use wrong function")
|
||||
end
|
||||
|
||||
return thoughtDict
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
# check that the dict has all required values
|
||||
thought::AbstractString = thoughtDict[:thought]
|
||||
actionname::AbstractString = thoughtDict[:action][:name]
|
||||
actioninput::AbstractString = thoughtDict[:action][:input]
|
||||
if actionname ∈ ["winestock", "chatbox", "recommendbox"]
|
||||
# LLM use available function
|
||||
elseif thought == ""
|
||||
error("DecisionMaker has no thought")
|
||||
elseif length(actioninput) == 0
|
||||
error("DecisionMaker has no actioninput")
|
||||
else
|
||||
error("DecisionMaker use wrong function")
|
||||
end
|
||||
|
||||
else
|
||||
error("DecisionMaker failed to generate a thought")
|
||||
return thoughtDict
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
end
|
||||
error("DecisionMaker failed to generate a thought")
|
||||
end
|
||||
|
||||
|
||||
@@ -297,7 +295,22 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
||||
}
|
||||
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||
"score": 10}
|
||||
"score": 10
|
||||
}
|
||||
|
||||
{
|
||||
"question": "Do you have an all-in-one pen with 4 colors and a pencil for sale?",
|
||||
"thought_1": "Let me check our inventory first to see if I have it.",
|
||||
"action_1": {"name": "inventory", "input": "pen with 4 color and a pencil."},
|
||||
"observation_1": "I found {1: "Pilot Dr. grip 4-in-1 pen", 2: "Rotting pencil"}",
|
||||
"thought_2": "Ok, I have what the user is asking. Let's tell the user.",
|
||||
"action_2": {"name": "chatbox", "input": "Yes, we do have a Pilot Dr. grip 4-in-1 pen and a Rotting pencil"},
|
||||
"observation_1": "This is not what I wanted."
|
||||
}
|
||||
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
||||
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
||||
"score": 2
|
||||
}
|
||||
|
||||
Let's begin!:
|
||||
$(JSON3.write(state[:thoughtHistory]))
|
||||
@@ -334,40 +347,34 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
||||
)
|
||||
)
|
||||
|
||||
attempt = 0
|
||||
while true
|
||||
attempt += 1
|
||||
if attempt <= 5
|
||||
try
|
||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
_responseJsonStr = response[:response][:text]
|
||||
expectedJsonExample =
|
||||
"""
|
||||
Here is an expected JSON format:
|
||||
{"evaluation": "...", "score": "..."}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
evaluationDict = copy(JSON3.read(responseJsonStr))
|
||||
for attempt in 1:5
|
||||
try
|
||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
_responseJsonStr = response[:response][:text]
|
||||
expectedJsonExample =
|
||||
"""
|
||||
Here is an expected JSON format:
|
||||
{"evaluation": "...", "score": "..."}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
evaluationDict = copy(JSON3.read(responseJsonStr))
|
||||
|
||||
# check that the dict has all required values
|
||||
dummya::AbstractString = evaluationDict[:evaluation]
|
||||
dummyb::Integer = evaluationDict[:score]
|
||||
# check that the dict has all required values
|
||||
dummya::AbstractString = evaluationDict[:evaluation]
|
||||
dummyb::Integer = evaluationDict[:score]
|
||||
|
||||
return (evaluationDict[:evaluation], evaluationDict[:score])
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
|
||||
else
|
||||
error("progressValueEstimator failed to generate an evaluation")
|
||||
return (evaluationDict[:evaluation], evaluationDict[:score])
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
end
|
||||
error("progressValueEstimator failed to generate an evaluation")
|
||||
end
|
||||
|
||||
|
||||
@@ -392,55 +399,6 @@ julia>
|
||||
function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
# https://github.com/andyz245/LanguageAgentTreeSearch/blob/main/hotpot/hotpot.py
|
||||
|
||||
# _prompt =
|
||||
# """
|
||||
# You are an advanced reasoning agent that can improve based on self refection.
|
||||
# You will be given the information related to the previous help you've done for a user.
|
||||
# You were unsuccessful in helping the user either because you guessed the wrong answer with Finish[answer], or you used up your set number of reasoning steps.
|
||||
# In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure.
|
||||
# Use complete sentences.
|
||||
|
||||
# You should only respond in JSON format as describe below:
|
||||
# {"reflection": "your relection"}
|
||||
|
||||
# Here are some examples:
|
||||
# Previous Trial:
|
||||
# {
|
||||
# "question": "Hello, I would like a get a bottle of wine",
|
||||
# "thought_1": "A customer wants to buy a bottle of wine. Before making a recommendation, I need to know more about their preferences.",
|
||||
# "action_1": {"name": "chatbox", "input": "What is the occasion for which you're buying this wine?"},
|
||||
# "observation_1": "We are holding a wedding party",
|
||||
|
||||
# "thought_2": "A wedding party, that's a great occasion! The customer might be looking for a celebratory drink. Let me ask some more questions to narrow down the options.",
|
||||
# "action_2": {"name": "chatbox", "input": "What type of food will you be serving at the wedding?"},
|
||||
# "observation_2": "It will be Thai dishes.",
|
||||
|
||||
# "thought_3": "With Thai food, I should recommend a wine that complements its spicy and savory flavors. And since it's a celebratory occasion, the customer might prefer a full-bodied wine.",
|
||||
# "action_3": {"name": "chatbox", "input": "What is your budget for this bottle of wine?"},
|
||||
# "observation_3": "I would spend up to 50 bucks.",
|
||||
|
||||
# "thought_4": "Now that I have some more information, it's time to narrow down the options.",
|
||||
# "action_4": {"name": "winestock", "input": "red wine with full body, pairs well with spicy food, budget \$50"},
|
||||
# "observation_4": "I found the following wines in our stock: \n{\n 1: El Enemigo Cabernet Franc 2019\n2: Tantara Chardonnay 2017\n\n}\n",
|
||||
|
||||
# "thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.",
|
||||
# "action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"},
|
||||
# "observation_5": "I like full-bodied Red wine with low tannin.",
|
||||
|
||||
# "thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.",
|
||||
# "action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"},
|
||||
# "observation_6": "I don't like the one you recommend. I want dry wine."
|
||||
# }
|
||||
|
||||
# {
|
||||
# "reflection": "I asked the user about the occasion, food type, and budget, and then searched for wine in the inventory right away. However, I should have asked the user for the specific wine type and their preferences in order to gather more information before making a recommendation."
|
||||
# }
|
||||
|
||||
# Previous trial:
|
||||
# $(JSON3.write(state[:thoughtHistory]))
|
||||
# {"reflection"
|
||||
# """
|
||||
|
||||
_prompt =
|
||||
"""
|
||||
You are a helpful sommelier working for a wine store.
|
||||
@@ -475,7 +433,7 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
|
||||
"thought_5": "Now that I have a list of potential wines, I need to know more about the customer's taste preferences.",
|
||||
"action_5": {"name": "chatbox", "input": "What type of wine characteristics are you looking for? (e.g. t.e.g. tannin level, sweetness, intensity, acidity)"},
|
||||
"observation_5": "I like full-bodied Red wine with low tannin.",
|
||||
"observation_5": "I like full-bodied red wine with low tannin.",
|
||||
|
||||
"thought_6": "Now that I have more information about the customer's preferences, it's time to make a recommendation.",
|
||||
"action_6": {"name": "recommendbox", "input": "El Enemigo Cabernet Franc 2019"},
|
||||
@@ -523,39 +481,33 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
|
||||
)
|
||||
)
|
||||
|
||||
attempt = 0
|
||||
while true
|
||||
attempt += 1
|
||||
if attempt <= 5
|
||||
try
|
||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
_responseJsonStr = response[:response][:text]
|
||||
expectedJsonExample =
|
||||
"""
|
||||
Here is an expected JSON format:
|
||||
{"reflection": "..."}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
reflectionDict = copy(JSON3.read(responseJsonStr))
|
||||
for attempt in 1:5
|
||||
try
|
||||
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||
_responseJsonStr = response[:response][:text]
|
||||
expectedJsonExample =
|
||||
"""
|
||||
Here is an expected JSON format:
|
||||
{"reflection": "..."}
|
||||
"""
|
||||
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
|
||||
reflectionDict = copy(JSON3.read(responseJsonStr))
|
||||
|
||||
# check that the dict has all required values
|
||||
dummya::AbstractString = reflectionDict[:reflection]
|
||||
# check that the dict has all required values
|
||||
dummya::AbstractString = reflectionDict[:reflection]
|
||||
|
||||
return reflectionDict[:reflection]
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
|
||||
else
|
||||
error("reflector failed to generate a thought")
|
||||
return reflectionDict[:reflection]
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Attempt $attempt. Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
end
|
||||
error("reflector failed to generate a thought")
|
||||
end
|
||||
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ function virtualWineCustomerReccommendbox(a::T1, input
|
||||
:text=> prompt,
|
||||
)
|
||||
)
|
||||
@show outgoingMsg
|
||||
|
||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||
response = result[:response]
|
||||
|
||||
@@ -162,28 +162,23 @@ function virtualWineCustomerChatbox(a::T1, input::T2
|
||||
)
|
||||
|
||||
attempt = 0
|
||||
while true
|
||||
attempt += 1
|
||||
if attempt <= 5
|
||||
try
|
||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||
response = result[:response]
|
||||
for attempt in 1:5
|
||||
try
|
||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||
response = result[:response]
|
||||
|
||||
return (response[:text], response[:select], response[:reward], response[:isterminal])
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
|
||||
else
|
||||
error("virtualWineCustomerChatbox failed to get a response")
|
||||
return (response[:text], response[:select], response[:reward], response[:isterminal])
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
end
|
||||
error("virtualWineCustomerChatbox failed to get a response")
|
||||
end
|
||||
|
||||
|
||||
@@ -252,77 +247,69 @@ julia>
|
||||
function jsoncorrection(a::T1, input::T2,
|
||||
correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
|
||||
|
||||
attemptround = 0
|
||||
incorrectjson = deepcopy(input)
|
||||
correctjson = nothing
|
||||
while true
|
||||
attemptround += 1
|
||||
if attemptround <= 5
|
||||
try
|
||||
d = copy(JSON3.read(incorrectjson))
|
||||
correctjson = incorrectjson
|
||||
break
|
||||
catch e
|
||||
@warn "Attempting correct JSON string. $attemptround"
|
||||
e = """$e"""
|
||||
if occursin("EOF", e)
|
||||
e = split(e, "EOF")[1] * "EOF"
|
||||
end
|
||||
incorrectjson = deepcopy(input)
|
||||
_prompt =
|
||||
"""
|
||||
Your goal are:
|
||||
1) Use the info why the given JSON string failed to load and provide a corrected version that can be loaded by Python's json.load function.
|
||||
2) The user need Corrected JSON string only. Do not provide any other info.
|
||||
|
||||
$correctJsonExample
|
||||
for attempt in 1:5
|
||||
try
|
||||
d = copy(JSON3.read(incorrectjson))
|
||||
correctjson = incorrectjson
|
||||
return correctjson
|
||||
catch e
|
||||
@warn "Attempting correct JSON string. Attempt $attempt"
|
||||
e = """$e"""
|
||||
if occursin("EOF", e)
|
||||
e = split(e, "EOF")[1] * "EOF"
|
||||
end
|
||||
incorrectjson = deepcopy(input)
|
||||
_prompt =
|
||||
"""
|
||||
Your goal are:
|
||||
1) Use the info why the given JSON string failed to load and provide a corrected version that can be loaded by Python's json.load function.
|
||||
2) The user need Corrected JSON string only. Do not provide any other info.
|
||||
|
||||
Let's begin!
|
||||
Given JSON string: $incorrectjson
|
||||
The given JSON string failed to load previously because: $e
|
||||
Corrected JSON string:
|
||||
"""
|
||||
$correctJsonExample
|
||||
|
||||
# apply LLM specific instruct format
|
||||
externalService = a.config[:externalservice][:text2textinstruct]
|
||||
llminfo = externalService[:llminfo]
|
||||
prompt =
|
||||
if llminfo[:name] == "llama3instruct"
|
||||
formatLLMtext_llama3instruct("system", _prompt)
|
||||
else
|
||||
error("llm model name is not defied yet $(@__LINE__)")
|
||||
end
|
||||
Let's begin!
|
||||
Given JSON string: $incorrectjson
|
||||
The given JSON string failed to load previously because: $e
|
||||
Corrected JSON string:
|
||||
"""
|
||||
|
||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||
msgMeta = GeneralUtils.generate_msgMeta(
|
||||
externalService[:mqtttopic],
|
||||
senderName= "jsoncorrection",
|
||||
senderId= a.id,
|
||||
receiverName= "text2textinstruct",
|
||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||
)
|
||||
# apply LLM specific instruct format
|
||||
externalService = a.config[:externalservice][:text2textinstruct]
|
||||
llminfo = externalService[:llminfo]
|
||||
prompt =
|
||||
if llminfo[:name] == "llama3instruct"
|
||||
formatLLMtext_llama3instruct("system", _prompt)
|
||||
else
|
||||
error("llm model name is not defied yet $(@__LINE__)")
|
||||
end
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> prompt,
|
||||
:kwargs=> Dict(
|
||||
:max_tokens=> 512,
|
||||
:stop=> ["<|eot_id|>"],
|
||||
)
|
||||
# send formatted input to user using GeneralUtils.sendReceiveMqttMsg
|
||||
msgMeta = GeneralUtils.generate_msgMeta(
|
||||
externalService[:mqtttopic],
|
||||
senderName= "jsoncorrection",
|
||||
senderId= a.id,
|
||||
receiverName= "text2textinstruct",
|
||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||
mqttBrokerPort= a.config[:mqttServerInfo][:port],
|
||||
)
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> prompt,
|
||||
:kwargs=> Dict(
|
||||
:max_tokens=> 512,
|
||||
:stop=> ["<|eot_id|>"],
|
||||
)
|
||||
)
|
||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||
incorrectjson = result[:response][:text]
|
||||
end
|
||||
else
|
||||
error("Can't fix JSON string")
|
||||
break
|
||||
)
|
||||
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
|
||||
incorrectjson = result[:response][:text]
|
||||
end
|
||||
end
|
||||
|
||||
return correctjson
|
||||
end
|
||||
|
||||
|
||||
|
||||
83
src/mcts.jl
83
src/mcts.jl
@@ -51,6 +51,7 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||
nodekey::T2
|
||||
state::T1
|
||||
visits::Integer
|
||||
progressvalue::Number
|
||||
statevalue::Number
|
||||
reward::Number
|
||||
isterminal::Bool
|
||||
@@ -78,26 +79,47 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||
max_uct = -Inf
|
||||
selectedNode = nothing
|
||||
maxUCT = -Inf
|
||||
selectedNode = nothing
|
||||
|
||||
for (childState, childNode) in node.children
|
||||
weightedterm =
|
||||
if node.visits == 0 || childNode.visits == 0
|
||||
0
|
||||
else
|
||||
w * sqrt(log(node.visits) / childNode.visits)
|
||||
end
|
||||
uctValue = childNode.statevalue + weightedterm
|
||||
|
||||
if uctValue > max_uct
|
||||
max_uct = uctValue
|
||||
selectedNode = childNode
|
||||
end
|
||||
for (childState, childNode) in node.children
|
||||
UCTvalue =
|
||||
if childNode.visits != 0
|
||||
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
|
||||
childNode.statevalue + weightedterm
|
||||
else # node.visits == 0 makes sqrt() in explore term error
|
||||
childNode.progressvalue # exploit term
|
||||
end
|
||||
|
||||
return selectedNode
|
||||
if UCTvalue > maxUCT
|
||||
maxUCT = UCTvalue
|
||||
selectedNode = childNode
|
||||
end
|
||||
end
|
||||
|
||||
return selectedNode
|
||||
end
|
||||
# function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||
# max_uct = -Inf
|
||||
# selectedNode = nothing
|
||||
|
||||
# for (childState, childNode) in node.children
|
||||
# weightedterm =
|
||||
# if node.visits == 0 || childNode.visits == 0 # node.visits == 0 makes sqrt() error
|
||||
# 0
|
||||
# else
|
||||
# w * sqrt(log(node.visits) / childNode.visits)
|
||||
# end
|
||||
# uctValue = childNode.statevalue + weightedterm
|
||||
|
||||
# if uctValue > max_uct
|
||||
# max_uct = uctValue
|
||||
# selectedNode = childNode
|
||||
# end
|
||||
# end
|
||||
|
||||
# return selectedNode
|
||||
# end
|
||||
|
||||
|
||||
""" Expand selected node
|
||||
@@ -139,14 +161,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= n
|
||||
println("---> expand() sample $nthSample")
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
println("---> expand() sample $nthSample")
|
||||
pprintln(node.state[:thoughtHistory])
|
||||
pprintln(thoughtDict)
|
||||
newNodeKey, newstate, reward, isterminalstate =
|
||||
MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
stateevaluation, progressvalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if reward < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
@@ -156,7 +179,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
# store new lesson for later use
|
||||
lessonDict = copy(JSON3.read("lesson.json"))
|
||||
latestLessonKey, latestLessonIndice =
|
||||
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
|
||||
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
|
||||
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||
newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||
lessonDict[newLessonKey] = newstate
|
||||
@@ -167,7 +190,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
end
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
else
|
||||
@@ -236,10 +259,6 @@ julia>
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||
# [WORKING] store best trajectory
|
||||
fullTrajectoryReward = 0
|
||||
isLeafNodeTerminalState = node.isterminal
|
||||
terminalStateReward = node.reward
|
||||
while !isroot(node)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
@@ -387,9 +406,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
|
||||
# loop through the node's children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
thisNodeProgressValue = childNode.statevalue + childNode.reward
|
||||
if thisNodeProgressValue > highestProgressValue
|
||||
highestProgressValue = thisNodeProgressValue
|
||||
potential = childNode.progressvalue + childNode.reward
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childNode.nodekey
|
||||
end
|
||||
end
|
||||
@@ -468,10 +487,12 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
|
||||
for nth in 1:maxIterations
|
||||
node = root
|
||||
node.visits += 1
|
||||
|
||||
while !isleaf(node)
|
||||
node = UCTselect(node, w)
|
||||
end
|
||||
@@ -481,14 +502,14 @@ function runMCTS(
|
||||
backpropagate(leafNode, node.reward)
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
||||
leafNode = UCTselect(node, w)
|
||||
leafNode = selectChildNode(node)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
reflector; maxDepth=maxDepth, n=n)
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
end
|
||||
|
||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||
best_child_state = argmax([child.statevalue / child.visits for child in values(root.children)])
|
||||
error("---> runMCTS")
|
||||
return best_child_state
|
||||
end
|
||||
|
||||
@@ -66,7 +66,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> "I like full-bodied Red wine with low tannin.",
|
||||
:text=> "I like full-bodied red wine with low tannin.",
|
||||
:select=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
@@ -134,7 +134,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> "Are there any other options?",
|
||||
:text=> "I like it dry.",
|
||||
:select=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
@@ -162,10 +162,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> "What are you saying. I don't understand.",
|
||||
:text=> "You didn't tell me wine name.",
|
||||
:select=> nothing,
|
||||
:reward=> -1,
|
||||
:isterminal=> false,
|
||||
:isterminal=> true,
|
||||
)
|
||||
)
|
||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
|
||||
Reference in New Issue
Block a user