This commit is contained in:
narawat lamaiin
2024-05-10 11:59:38 +07:00
parent c7fd7bc40d
commit a4ba292fad
4 changed files with 107 additions and 44 deletions

View File

@@ -91,10 +91,17 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
$(JSON3.write(state[:storeinfo]))
"""
lessonDict = copy(JSON3.read("lesson.json"))
lesson =
if isempty(a.lesson)
if isempty(lessonDict)
""
else
lessons = Dict{Symbol, Any}()
for (k, v) in lessonDict
lessons[k] = lessonDict[k][:lesson]
end
"""
You have attempted to help the user before and failed, either because your reasoning for the
recommendation was incorrect or your response did not exactly match the user expectation.
@@ -102,7 +109,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
did previously. Use them to improve your strategy to help the user.
Here are some lessons:
$(JSON3.write(a.lesson[:lesson_1][:lesson]))
$(JSON3.write(lessons))
When providing the thought and action for the current trial, that into account these failed
trajectories and make sure not to repeat the same mistakes and incorrect answers.
@@ -211,8 +218,14 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
# check if dict has all required value
dummya::AbstractString = thoughtDict[:thought]
dummyb::AbstractString = thoughtDict[:action][:name]
dummyc::AbstractString = thoughtDict[:action][:input]
actionname::AbstractString = thoughtDict[:action][:name]
actioninput::AbstractString = thoughtDict[:action][:input]
if actionname ["winestock", "chatbox", "recommendbox"]
# LLM use available function
else
error("DecisionMaker use wrong function")
end
return thoughtDict
catch e

View File

"""
    expand(a, node, decisionMaker, progressValueEstimator, reflector; n=3)

Sample `n` candidate transitions from `node` and attach every previously unseen outcome
to `node.children`.

For each sample the function:
1. asks `decisionMaker(a, node.state)` for a thought/action dictionary,
2. applies it with `MCTStransition`, obtaining the child key, the new state, the reward
   and the terminal flag,
3. evaluates the new state with `progressValueEstimator(a, newstate)`,
4. when the reward is negative, stores the evaluation and the output of
   `reflector(a, newstate)` in the new state and appends that state to `lesson.json`
   under the next free `lesson_<i>` key,
5. inserts a new `MCTSNode` into `node.children` unless the key is already present.

# Arguments
- `a`: the agent, forwarded unchanged to every callback.
- `node::MCTSNode`: the node being expanded; its `children` dictionary is mutated.
- `decisionMaker::Function`: called as `decisionMaker(a, state)`.
- `progressValueEstimator::Function`: called as `progressValueEstimator(a, state)`;
  must return an `(evaluation, value)` pair.
- `reflector::Function`: called as `reflector(a, state)` for negatively rewarded states.

# Keywords
- `n::Integer=3`: number of samples to draw.

# Returns
`nothing`; results are recorded in `node.children` and in `lesson.json`.
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
    progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}

    for nthSample in 1:n
        println("---> expand() sample $nthSample")
        thoughtDict = decisionMaker(a, node.state)
        newNodeKey, newstate, reward, isterminalstate =
            MCTStransition(a, node.state, thoughtDict)

        # add progressValueEstimator
        stateevaluation, statevalue = progressValueEstimator(a, newstate)

        if reward < 0
            pprint(newstate[:thoughtHistory])
            newstate[:evaluation] = stateevaluation
            newstate[:lesson] = reflector(a, newstate)

            # Store the new lesson for later use. The next index is derived from the
            # lessons already on disk: `state` is not defined in this function, and the
            # thought history holds no lesson keys to count.
            # NOTE(review): assumes findHighestIndexKey accepts the Dict produced by
            # copy(JSON3.read(...)) just as it accepts a thought-history Dict -- confirm.
            lessonDict = copy(JSON3.read("lesson.json"))
            latestLessonKey, latestLessonIndice =
                GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
            nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
            newLessonKey = Symbol("lesson_$(nextIndice)")
            lessonDict[newLessonKey] = newstate
            open("lesson.json", "w") do io
                JSON3.pretty(io, lessonDict)
            end
            print("---> reflector()")
        end

        # Only add outcomes that have not been seen yet; existing children are kept.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
                reward, isterminalstate, node, Dict{String, MCTSNode}())
        end
    end

    return nothing
end
# function expand(a::T1, node::MCTSNode, decisionMaker::Function,
# progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
# nthSample = 0
# while nthSample <= n
# nthSample += 1
# println("---> expand() sample $nthSample")
# thoughtDict = decisionMaker(a, node.state)
# newNodeKey, newstate, reward, isterminalstate =
# MCTStransition(a, node.state, thoughtDict)
# # add progressValueEstimator
# stateevaluation, statevalue = progressValueEstimator(a, newstate)
# if reward < 0
# pprint(newstate[:thoughtHistory])
# newstate[:evaluation] = stateevaluation
# newstate[:lesson] = reflector(a, newstate)
# # store new lesson for later use
# lessonDict = copy(JSON3.read("lesson.json"))
# latestLessonKey, latestLessonIndice =
# GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
# nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
# newLessonKey = Symbol("lesson_$(nextIndice)")
# lessonDict[newLessonKey] = newstate
# open("lesson.json", "w") do io
# JSON3.pretty(io, lessonDict)
# end
# print("---> reflector()")
# end
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
# reward, isterminalstate, node, Dict{String, MCTSNode}())
# end
# end
# end
@@ -299,8 +352,8 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
latestObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
newObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal

View File

@@ -101,7 +101,6 @@ julia> agent = YiemAgent.bsommelier(
:customerinfo => Dict{Symbol, Any}(),
:storeinfo => Dict{Symbol, Any}(),
)
lesson::Dict{Symbol, Any} = Dict{Symbol, Any}()
mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
# 1-historyPoint is in Dict{Symbol, Any} and compose of:

View File

@@ -116,6 +116,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
@@ -144,20 +145,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I already told you I like Red wine. Why did you ask me about other wine type?",
:select=> nothing,
:reward=> -1,
:isterminal=> false,
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
@@ -175,7 +162,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I didn't like the one you recommend. You like dry wine.",
:text=> "What are you saying. I don't understand.",
:select=> nothing,
:reward=> -1,
:isterminal=> false,
@@ -187,4 +174,15 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
# Build the next message: msgMeta plus a payload carrying the reply text, no selection,
# a reward of 0 and a non-terminal flag.
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I like dry wine with medium acidity.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
)
)
# Send the message through GeneralUtils.sendMqttMsg and keep its return value in `result`.
result = GeneralUtils.sendMqttMsg(outgoingMsg)