update
This commit is contained in:
@@ -91,10 +91,17 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
$(JSON3.write(state[:storeinfo]))
|
||||
"""
|
||||
|
||||
lessonDict = copy(JSON3.read("lesson.json"))
|
||||
|
||||
lesson =
|
||||
if isempty(a.lesson)
|
||||
if isempty(lessonDict)
|
||||
""
|
||||
else
|
||||
lessons = Dict{Symbol, Any}()
|
||||
for (k, v) in lessonDict
|
||||
lessons[k] = lessonDict[k][:lesson]
|
||||
end
|
||||
|
||||
"""
|
||||
You have attempted to help the user before and failed, either because your reasoning for the
|
||||
recommendation was incorrect or your response did not exactly match the user expectation.
|
||||
@@ -102,7 +109,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
did previously. Use them to improve your strategy to help the user.
|
||||
|
||||
Here are some lessons:
|
||||
$(JSON3.write(a.lesson[:lesson_1][:lesson]))
|
||||
$(JSON3.write(lessons))
|
||||
|
||||
When providing the thought and action for the current trial, that into account these failed
|
||||
trajectories and make sure not to repeat the same mistakes and incorrect answers.
|
||||
@@ -211,8 +218,14 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
||||
|
||||
# check that the dict has all required values
|
||||
dummya::AbstractString = thoughtDict[:thought]
|
||||
dummyb::AbstractString = thoughtDict[:action][:name]
|
||||
dummyc::AbstractString = thoughtDict[:action][:input]
|
||||
actionname::AbstractString = thoughtDict[:action][:name]
|
||||
actioninput::AbstractString = thoughtDict[:action][:input]
|
||||
|
||||
if actionname ∈ ["winestock", "chatbox", "recommendbox"]
|
||||
# LLM use available function
|
||||
else
|
||||
error("DecisionMaker use wrong function")
|
||||
end
|
||||
|
||||
return thoughtDict
|
||||
catch e
|
||||
|
||||
101
src/mcts.jl
101
src/mcts.jl
@@ -136,31 +136,84 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
|
||||
|
||||
nthSample = 0
|
||||
while nthSample < n
|
||||
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, reward, isterminalstate =
|
||||
MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if reward < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
newstate[:evaluation] = stateevaluation
|
||||
newstate[:lesson] = reflector(a, newstate)
|
||||
a.lesson[:lesson_1] = deepcopy(newstate)
|
||||
print("---> reflector()")
|
||||
end
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= n
|
||||
println("---> expand() sample $nthSample")
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, reward, isterminalstate =
|
||||
MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if reward < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
newstate[:evaluation] = stateevaluation
|
||||
newstate[:lesson] = reflector(a, newstate)
|
||||
|
||||
# store new lesson for later use
|
||||
lessonDict = copy(JSON3.read("lesson.json"))
|
||||
latestLessonKey, latestLessonIndice =
|
||||
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
|
||||
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||
newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||
lessonDict[newLessonKey] = newstate
|
||||
open("lesson.json", "w") do io
|
||||
JSON3.pretty(io, lessonDict)
|
||||
end
|
||||
print("---> reflector()")
|
||||
end
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
# function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
# progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
|
||||
|
||||
# nthSample = 0
|
||||
# while nthSample <= n
|
||||
# nthSample += 1
|
||||
# println("---> expand() sample $nthSample")
|
||||
# thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
# newNodeKey, newstate, reward, isterminalstate =
|
||||
# MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# # add progressValueEstimator
|
||||
# stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
# if reward < 0
|
||||
# pprint(newstate[:thoughtHistory])
|
||||
# newstate[:evaluation] = stateevaluation
|
||||
# newstate[:lesson] = reflector(a, newstate)
|
||||
|
||||
# # store new lesson for later use
|
||||
# lessonDict = copy(JSON3.read("lesson.json"))
|
||||
# latestLessonKey, latestLessonIndice =
|
||||
# GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
|
||||
# nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||
# newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||
# lessonDict[newLessonKey] = newstate
|
||||
# open("lesson.json", "w") do io
|
||||
# JSON3.pretty(io, lessonDict)
|
||||
# end
|
||||
# print("---> reflector()")
|
||||
# end
|
||||
|
||||
# if newNodeKey ∉ keys(node.children)
|
||||
# node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
# reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
|
||||
@@ -299,8 +352,8 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
|
||||
newstate = deepcopy(state)
|
||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
||||
latestObservationKey = Symbol("observation_$(nextIndice)")
|
||||
newstate[:thoughtHistory][latestObservationKey] = response
|
||||
newObservationKey = Symbol("observation_$(nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
@@ -101,7 +101,6 @@ julia> agent = YiemAgent.bsommelier(
|
||||
:customerinfo => Dict{Symbol, Any}(),
|
||||
:storeinfo => Dict{Symbol, Any}(),
|
||||
)
|
||||
lesson::Dict{Symbol, Any} = Dict{Symbol, Any}()
|
||||
mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
|
||||
|
||||
# 1-historyPoint is a Dict{Symbol, Any} and is composed of:
|
||||
|
||||
@@ -116,6 +116,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
|
||||
|
||||
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
@@ -144,20 +145,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
|
||||
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> "I already told you I like Red wine. Why did you ask me about other wine type?",
|
||||
:select=> nothing,
|
||||
:reward=> -1,
|
||||
:isterminal=> false,
|
||||
)
|
||||
)
|
||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
|
||||
|
||||
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
@@ -175,7 +162,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> "I didn't like the one you recommend. You like dry wine.",
|
||||
:text=> "What are you saying. I don't understand.",
|
||||
:select=> nothing,
|
||||
:reward=> -1,
|
||||
:isterminal=> false,
|
||||
@@ -187,4 +174,15 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
|
||||
|
||||
|
||||
outgoingMsg = Dict(
|
||||
:msgMeta=> msgMeta,
|
||||
:payload=> Dict(
|
||||
:text=> "I like dry wine with medium acidity.",
|
||||
:select=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
)
|
||||
)
|
||||
result = GeneralUtils.sendMqttMsg(outgoingMsg)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user