This commit is contained in:
narawat lamaiin
2024-05-10 11:59:38 +07:00
parent c7fd7bc40d
commit a4ba292fad
4 changed files with 107 additions and 44 deletions

View File

@@ -91,10 +91,17 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
$(JSON3.write(state[:storeinfo])) $(JSON3.write(state[:storeinfo]))
""" """
lessonDict = copy(JSON3.read("lesson.json"))
lesson = lesson =
if isempty(a.lesson) if isempty(lessonDict)
"" ""
else else
lessons = Dict{Symbol, Any}()
for (k, v) in lessonDict
lessons[k] = lessonDict[k][:lesson]
end
""" """
You have attempted to help the user before and failed, either because your reasoning for the You have attempted to help the user before and failed, either because your reasoning for the
recommendation was incorrect or your response did not exactly match the user expectation. recommendation was incorrect or your response did not exactly match the user expectation.
@@ -102,7 +109,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
did previously. Use them to improve your strategy to help the user. did previously. Use them to improve your strategy to help the user.
Here are some lessons: Here are some lessons:
$(JSON3.write(a.lesson[:lesson_1][:lesson])) $(JSON3.write(lessons))
When providing the thought and action for the current trial, that into account these failed When providing the thought and action for the current trial, that into account these failed
trajectories and make sure not to repeat the same mistakes and incorrect answers. trajectories and make sure not to repeat the same mistakes and incorrect answers.
@@ -211,8 +218,14 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
# check if dict has all required value # check if dict has all required value
dummya::AbstractString = thoughtDict[:thought] dummya::AbstractString = thoughtDict[:thought]
dummyb::AbstractString = thoughtDict[:action][:name] actionname::AbstractString = thoughtDict[:action][:name]
dummyc::AbstractString = thoughtDict[:action][:input] actioninput::AbstractString = thoughtDict[:action][:input]
if actionname ["winestock", "chatbox", "recommendbox"]
# LLM use available function
else
error("DecisionMaker use wrong function")
end
return thoughtDict return thoughtDict
catch e catch e

View File

@@ -136,31 +136,84 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent} progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
nthSample = 0 nthSample = 0
while nthSample < n while true
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
if reward < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
a.lesson[:lesson_1] = deepcopy(newstate)
print("---> reflector()")
end
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
nthSample += 1 nthSample += 1
if nthSample <= n
println("---> expand() sample $nthSample")
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
if reward < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
else
break
end
end end
end end
# function expand(a::T1, node::MCTSNode, decisionMaker::Function,
# progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
# nthSample = 0
# while nthSample <= n
# nthSample += 1
# println("---> expand() sample $nthSample")
# thoughtDict = decisionMaker(a, node.state)
# newNodeKey, newstate, reward, isterminalstate =
# MCTStransition(a, node.state, thoughtDict)
# # add progressValueEstimator
# stateevaluation, statevalue = progressValueEstimator(a, newstate)
# if reward < 0
# pprint(newstate[:thoughtHistory])
# newstate[:evaluation] = stateevaluation
# newstate[:lesson] = reflector(a, newstate)
# # store new lesson for later use
# lessonDict = copy(JSON3.read("lesson.json"))
# latestLessonKey, latestLessonIndice =
# GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
# nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
# newLessonKey = Symbol("lesson_$(nextIndice)")
# lessonDict[newLessonKey] = newstate
# open("lesson.json", "w") do io
# JSON3.pretty(io, lessonDict)
# end
# print("---> reflector()")
# end
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
# reward, isterminalstate, node, Dict{String, MCTSNode}())
# end
# end
# end
@@ -299,8 +352,8 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
newstate = deepcopy(state) newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
latestObservationKey = Symbol("observation_$(nextIndice)") newObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward newstate[:reward] = reward
newstate[:select] = select newstate[:select] = select
newstate[:isterminal] = isterminal newstate[:isterminal] = isterminal

View File

@@ -101,7 +101,6 @@ julia> agent = YiemAgent.bsommelier(
:customerinfo => Dict{Symbol, Any}(), :customerinfo => Dict{Symbol, Any}(),
:storeinfo => Dict{Symbol, Any}(), :storeinfo => Dict{Symbol, Any}(),
) )
lesson::Dict{Symbol, Any} = Dict{Symbol, Any}()
mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}() mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
# 1-historyPoint is in Dict{Symbol, Any} and compose of: # 1-historyPoint is in Dict{Symbol, Any} and compose of:

View File

@@ -116,6 +116,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
@@ -144,20 +145,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I already told you I like Red wine. Why did you ask me about other wine type?",
:select=> nothing,
:reward=> -1,
:isterminal=> false,
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
@@ -175,7 +162,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "I didn't like the one you recommend. You like dry wine.", :text=> "What are you saying. I don't understand.",
:select=> nothing, :select=> nothing,
:reward=> -1, :reward=> -1,
:isterminal=> false, :isterminal=> false,
@@ -187,4 +174,15 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "I like dry wine with medium acidity.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)