From a4ba292fad340294f232713465f546fd4ecb083d Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Fri, 10 May 2024 11:59:38 +0700 Subject: [PATCH] update --- src/interface.jl | 21 ++++++++-- src/mcts.jl | 101 ++++++++++++++++++++++++++++++++++++----------- src/type.jl | 1 - test/test_1.jl | 28 ++++++------- 4 files changed, 107 insertions(+), 44 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 7da35dd..6f384fa 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -91,10 +91,17 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 $(JSON3.write(state[:storeinfo])) """ + lessonDict = copy(JSON3.read("lesson.json")) + lesson = - if isempty(a.lesson) + if isempty(lessonDict) "" else + lessons = Dict{Symbol, Any}() + for (k, v) in lessonDict + lessons[k] = lessonDict[k][:lesson] + end + """ You have attempted to help the user before and failed, either because your reasoning for the recommendation was incorrect or your response did not exactly match the user expectation. @@ -102,7 +109,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 did previously. Use them to improve your strategy to help the user. Here are some lessons: - $(JSON3.write(a.lesson[:lesson_1][:lesson])) + $(JSON3.write(lessons)) When providing the thought and action for the current trial, that into account these failed trajectories and make sure not to repeat the same mistakes and incorrect answers. @@ -211,8 +218,14 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 # check if dict has all required value dummya::AbstractString = thoughtDict[:thought] - dummyb::AbstractString = thoughtDict[:action][:name] - dummyc::AbstractString = thoughtDict[:action][:input] + actionname::AbstractString = thoughtDict[:action][:name] + actioninput::AbstractString = thoughtDict[:action][:input] + + if actionname ∈ ["winestock", "chatbox", "recommendbox"] + # LLM use available function + else + error("DecisionMaker use wrong function") + end return thoughtDict catch e diff --git a/src/mcts.jl b/src/mcts.jl index 4b72aad..7d30568 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -136,31 +136,84 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent} nthSample = 0 - while nthSample < n - - thoughtDict = decisionMaker(a, node.state) - - newNodeKey, newstate, reward, isterminalstate = - MCTStransition(a, node.state, thoughtDict) - - # add progressValueEstimator - stateevaluation, statevalue = progressValueEstimator(a, newstate) - - if reward < 0 - pprint(newstate[:thoughtHistory]) - newstate[:evaluation] = stateevaluation - newstate[:lesson] = reflector(a, newstate) - a.lesson[:lesson_1] = deepcopy(newstate) - print("---> reflector()") - end - - if newNodeKey ∉ keys(node.children) - node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue, - reward, isterminalstate, node, Dict{String, MCTSNode}()) - end + while true nthSample += 1 + if nthSample <= n + println("---> expand() sample $nthSample") + thoughtDict = decisionMaker(a, node.state) + + newNodeKey, newstate, reward, isterminalstate = + MCTStransition(a, node.state, thoughtDict) + + # add progressValueEstimator + stateevaluation, statevalue = progressValueEstimator(a, newstate) + + if reward < 0 + pprint(newstate[:thoughtHistory]) + newstate[:evaluation] = stateevaluation + newstate[:lesson] = reflector(a, newstate) + + # store new lesson for later use + lessonDict = copy(JSON3.read("lesson.json")) + latestLessonKey, latestLessonIndice = + GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson") + nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1 + newLessonKey = Symbol("lesson_$(nextIndice)") + lessonDict[newLessonKey] = newstate + open("lesson.json", "w") do io + JSON3.pretty(io, lessonDict) + end + print("---> reflector()") + end + + if newNodeKey ∉ keys(node.children) + node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue, + reward, isterminalstate, node, Dict{String, MCTSNode}()) + end + else + break + end end end +# function expand(a::T1, node::MCTSNode, decisionMaker::Function, +# progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent} + +# nthSample = 0 +# while nthSample <= n +# nthSample += 1 +# println("---> expand() sample $nthSample") +# thoughtDict = decisionMaker(a, node.state) + +# newNodeKey, newstate, reward, isterminalstate = +# MCTStransition(a, node.state, thoughtDict) + +# # add progressValueEstimator +# stateevaluation, statevalue = progressValueEstimator(a, newstate) + +# if reward < 0 +# pprint(newstate[:thoughtHistory]) +# newstate[:evaluation] = stateevaluation +# newstate[:lesson] = reflector(a, newstate) + +# # store new lesson for later use +# lessonDict = copy(JSON3.read("lesson.json")) +# latestLessonKey, latestLessonIndice = +# GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson") +# nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1 +# newLessonKey = Symbol("lesson_$(nextIndice)") +# lessonDict[newLessonKey] = newstate +# open("lesson.json", "w") do io +# JSON3.pretty(io, lessonDict) +# end +# print("---> reflector()") +# end + +# if newNodeKey ∉ keys(node.children) +# node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue, +# reward, isterminalstate, node, Dict{String, MCTSNode}()) +# end +# end +# end @@ -299,8 +352,8 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3 newstate = deepcopy(state) newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] - latestObservationKey = Symbol("observation_$(nextIndice)") - newstate[:thoughtHistory][latestObservationKey] = response + newObservationKey = Symbol("observation_$(nextIndice)") + newstate[:thoughtHistory][newObservationKey] = response newstate[:reward] = reward newstate[:select] = select newstate[:isterminal] = isterminal diff --git a/src/type.jl b/src/type.jl index 3ecc0da..3d4f317 100644 --- a/src/type.jl +++ b/src/type.jl @@ -101,7 +101,6 @@ julia> agent = YiemAgent.bsommelier( :customerinfo => Dict{Symbol, Any}(), :storeinfo => Dict{Symbol, Any}(), ) - lesson::Dict{Symbol, Any} = Dict{Symbol, Any}() mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}() # 1-historyPoint is in Dict{Symbol, Any} and compose of: diff --git a/test/test_1.jl b/test/test_1.jl index 4d10d77..b2430e9 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -116,6 +116,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) + outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( @@ -144,20 +145,6 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) -outgoingMsg = Dict( - :msgMeta=> msgMeta, - :payload=> Dict( - :text=> "I already told you I like Red wine. Why did you ask me about other wine type?", - :select=> nothing, - :reward=> -1, - :isterminal=> false, - ) -) -result = GeneralUtils.sendMqttMsg(outgoingMsg) - - - - outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( @@ -175,7 +162,7 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "I didn't like the one you recommend. You like dry wine.", + :text=> "What are you saying. I don't understand.", :select=> nothing, :reward=> -1, :isterminal=> false, @@ -187,4 +174,15 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) +outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> "I like dry wine with medium acidity.", + :select=> nothing, + :reward=> 0, + :isterminal=> false, + ) +) +result = GeneralUtils.sendMqttMsg(outgoingMsg) +