This commit is contained in:
narawat lamaiin
2024-05-15 13:35:26 +07:00
parent 62c6ce90ed
commit e9c91fdb4d
6 changed files with 179 additions and 77 deletions

View File

@@ -39,7 +39,8 @@ macro executeStringFunction(functionStr, args...)
func_expr = Meta.parse(functionStr) func_expr = Meta.parse(functionStr)
# Create a new function with the parsed expression # Create a new function with the parsed expression
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...), func_expr.args[2:end]...)) function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...),
func_expr.args[2:end]...))
# Call the newly created function with the provided arguments # Call the newly created function with the provided arguments
function_to_call(args...) function_to_call(args...)
@@ -744,47 +745,61 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
# Signature # Signature
""" """
function conversation(a::T, userinput::Dict) where {T<:agent} function conversation(a::T, userinput::Dict) where {T<:agent}
# "newtopic" command to delete chat history
if userinput[:text] == "newtopic" if userinput[:text] == "newtopic"
clearhistory(a) clearhistory(a)
return "Okay. What shall we talk about?" return "Okay. What shall we talk about?"
else else
# add usermsg to a.chathistory # add usermsg to a.chathistory
addNewMessage(a, "user", userinput[:text]) addNewMessage(a, "user", userinput[:text])
currentstate =
if isempty(a.plan[:currenttrajectory]) if isempty(a.plan[:currenttrajectory])
# set up initial state a.plan[:currenttrajectory] = Dict{Symbol, Any}(
Dict{Symbol, Any}( # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), :userselect=> nothing,
:userselect=> nothing, :reward=> 0,
:reward=> 0, :isterminal=> false,
:isterminal=> false, :evaluation=> nothing,
:evaluation=> nothing, :lesson=> nothing,
:lesson=> nothing,
:thoughtDict=> nothing,
:totalTrajectoryReward=> nothing,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# :recap=>,
:question=> userinput[:text],
)
)
else
a.plan[:currenttrajectory]
end
bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector, :totalTrajectoryReward=> nothing,
totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
# contain question, thought_1, action_1, observation_1, thought_2, ...
:thoughtHistory=> OrderedDict{Symbol, Any}(
#[] :recap=>,
:question=> userinput[:text],
)
)
else
_, a.plan[:currenttrajectory] = makeNewState(a.plan[:currenttrajectory],
a.plan[:activeplan][:thoughtHistory], userinput[:text], userinput[:select],
userinput[:reward], userinput[:isterminal])
end
end
while true
bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker,
evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=1, explorationweight=1.0)
a.plan[:activeplan] = bestNextState
latestActionKey, latestActionIndice =
GeneralUtils.findHighestIndexKey(bestNextState[:thoughtHistory], "action")
actionname = bestNextState[:thoughtHistory][latestActionKey][:name]
actioninput = bestNextState[:thoughtHistory][latestActionKey][:input]
# transition # transition
newstate = transition(a, bestNextState) if actionname == "chatbox"
a.plan[:currenttrajectory] = newstate # add usermsg to a.chathistory
addNewMessage(a, "assistant", actioninput)
return actioninput
elseif actionname == "recommendbox"
# add usermsg to a.chathistory
addNewMessage(a, "assistant", actioninput)
return actioninput
else
_, a.plan[:currenttrajectory] = transition(a, a.plan[:activeplan])
end
end end
end end
@@ -797,6 +812,62 @@ end
# function conversation(a::T, userinput::Dict) where {T<:agent}
# # get new user msg from a.receiveUserMsgChannel
# # "newtopic" command to delete chat history
# if userinput[:text] == "newtopic"
# clearhistory(a)
# return "Okay. What shall we talk about?"
# else
# # add usermsg to a.chathistory
# addNewMessage(a, "user", userinput[:text])
# currentstate =
# if isempty(a.plan[:currenttrajectory])
# # set up initial state
# Dict{Symbol, Any}(
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
# :userselect=> nothing,
# :reward=> 0,
# :isterminal=> false,
# :evaluation=> nothing,
# :lesson=> nothing,
# :thoughtDict=> nothing,
# :totalTrajectoryReward=> nothing,
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# # :recap=>,
# :question=> userinput[:text],
# )
# )
# else
# a.plan[:currenttrajectory]
# end
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
# # transition
# newstate = transition(a, bestNextState)
# a.plan[:currenttrajectory] = newstate
# end
# end

View File

@@ -22,7 +22,7 @@ julia>
# TODO # TODO
- [] update docstring - [] update docstring
- [PENDING] implement the function - [WORKING] implement the function
# Signature # Signature
""" """
@@ -293,7 +293,7 @@ function jsoncorrection(a::T1, input::T2,
correctjson = incorrectjson correctjson = incorrectjson
return correctjson return correctjson
catch e catch e
@warn "Attempting correct JSON string. Attempt $attempt" @warn "Attempting to correct JSON string. Attempt $attempt"
e = """$e""" e = """$e"""
if occursin("EOF", e) if occursin("EOF", e)
e = split(e, "EOF")[1] * "EOF" e = split(e, "EOF")[1] * "EOF"

View File

@@ -6,7 +6,7 @@
module mcts module mcts
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition, export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
userChatbox userChatbox, makeNewState
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils using GeneralUtils
@@ -144,10 +144,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
println("---> expand() sample $nthSample") println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory]) pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict) pprintln(thoughtDict)
node.state[:thoughtDict] = thoughtDict newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict)
newNodeKey, newstate = MCTStransition(a, node.state)
# add evaluator
stateevaluation, progressvalue = evaluator(a, newstate) stateevaluation, progressvalue = evaluator(a, newstate)
if newstate[:reward] < 0 if newstate[:reward] < 0
@@ -294,10 +293,9 @@ julia> thoughtDict = Dict(
# Signature # Signature
""" """
function MCTStransition(a::T1, state::T2 function MCTStransition(a::T1, state::T2, thoughtDict::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
thoughtDict = state[:thoughtDict]
actionname = thoughtDict[:action][:name] actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input] actioninput = thoughtDict[:action][:input]
@@ -313,25 +311,7 @@ function MCTStransition(a::T1, state::T2
error("undefined LLM function. Requesting $actionname") error("undefined LLM function. Requesting $actionname")
end end
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], return makeNewState(state, thoughtDict, response, select, reward, isterminal)
"thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("thought_$nextIndice")
latestActionKey = Symbol("action_$nextIndice")
# add Thought, action, observation to thoughtHistory
newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
newObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
end end
@@ -374,7 +354,7 @@ julia> thoughtDict = Dict(
# Signature # Signature
""" """
function transition(a::T1, state::T2 function transition(a::T1, state::T2, thoughtDict::T2
)::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict} )::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
thoughtDict = state[:thoughtDict] thoughtDict = state[:thoughtDict]
@@ -383,36 +363,74 @@ function transition(a::T1, state::T2
# map action and input() to llm function # map action and input() to llm function
response, select, reward, isterminal = response, select, reward, isterminal =
if actionname == "chatbox" if actionname == "winestock"
userChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock"
winestock(a, actioninput) winestock(a, actioninput)
elseif actionname == "recommendbox"
userRecommendbox(a, actioninput)
else else
error("undefined LLM function. Requesting $actionname") error("undefined LLM function. Requesting $actionname")
end end
latestThoughtKey, latestThoughtIndice = return makeNewState(state, thoughtDict, response, select, reward, isterminal)
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought") end
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("thought_$nextIndice")
latestActionKey = Symbol("action_$nextIndice") """
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [TESTING] implement the function
# Signature
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice")
_, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1
(:thought, :action)
else
(
Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"),
)
end
# add Thought, action, observation to thoughtHistory # add Thought, action, observation to thoughtHistory
newstate = deepcopy(state) newstate = deepcopy(currentstate)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought] newstate[:thoughtHistory][currentstate_latestThoughtKey] =
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action] thoughtDict[thoughtDict_latestThoughtKey]
newObservationKey = Symbol("observation_$(nextIndice)") newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward newstate[:reward] = reward
newstate[:select] = select newstate[:select] = select
newstate[:isterminal] = isterminal newstate[:isterminal] = isterminal
return newstate newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
end end
""" Determine whether a node is a leaf node of a search tree. """ Determine whether a node is a leaf node of a search tree.
# Arguments # Arguments

View File

@@ -111,8 +111,8 @@ julia> agent = YiemAgent.bsommelier(
# each plan is in [historyPoint_1, historyPoint_2, ...] format # each plan is in [historyPoint_1, historyPoint_2, ...] format
:existingplan => Vector(), :existingplan => Vector(),
:activeplan => Vector{Dict{Symbol, Any}}(), # current using plan :activeplan => Dict{Symbol, Any}(), # current using plan
:currenttrajectory=> Vector{Dict{Symbol, Any}}(), # store :currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
) )
# put incoming message here. waiting for further processing # put incoming message here. waiting for further processing

View File

@@ -59,10 +59,21 @@ tools=Dict( # update input format
tools=tools, tools=tools,
) )
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",) ) # response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
println("---> YiemAgent: ", response)
response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
println("---> YiemAgent: ", response)

View File

@@ -1,6 +1,8 @@
using Revise using Revise
using YiemAgent, GeneralUtils, JSON3, DataStructures using YiemAgent, GeneralUtils, JSON3, DataStructures
# ---------------------------------------------- 100 --------------------------------------------- #
msgMeta = Dict(:requestResponse => nothing, msgMeta = Dict(:requestResponse => nothing,
:msgPurpose => nothing, :msgPurpose => nothing,
:receiverId => nothing, :receiverId => nothing,