This commit is contained in:
narawat lamaiin
2024-05-15 13:35:26 +07:00
parent 62c6ce90ed
commit e9c91fdb4d
6 changed files with 179 additions and 77 deletions

View File

@@ -39,7 +39,8 @@ macro executeStringFunction(functionStr, args...)
func_expr = Meta.parse(functionStr)
# Create a new function with the parsed expression
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...), func_expr.args[2:end]...))
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...),
func_expr.args[2:end]...))
# Call the newly created function with the provided arguments
function_to_call(args...)
@@ -744,21 +745,15 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
# Signature
"""
function conversation(a::T, userinput::Dict) where {T<:agent}
# "newtopic" command to delete chat history
if userinput[:text] == "newtopic"
clearhistory(a)
return "Okay. What shall we talk about?"
else
# add usermsg to a.chathistory
addNewMessage(a, "user", userinput[:text])
currentstate =
if isempty(a.plan[:currenttrajectory])
# set up initial state
Dict{Symbol, Any}(
a.plan[:currenttrajectory] = Dict{Symbol, Any}(
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
@@ -767,24 +762,44 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
:isterminal=> false,
:evaluation=> nothing,
:lesson=> nothing,
:thoughtDict=> nothing,
:totalTrajectoryReward=> nothing,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# :recap=>,
# contain question, thought_1, action_1, observation_1, thought_2, ...
:thoughtHistory=> OrderedDict{Symbol, Any}(
#[] :recap=>,
:question=> userinput[:text],
)
)
else
a.plan[:currenttrajectory]
_, a.plan[:currenttrajectory] = makeNewState(a.plan[:currenttrajectory],
a.plan[:activeplan][:thoughtHistory], userinput[:text], userinput[:select],
userinput[:reward], userinput[:isterminal])
end
end
bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
while true
bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker,
evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=1, explorationweight=1.0)
a.plan[:activeplan] = bestNextState
latestActionKey, latestActionIndice =
GeneralUtils.findHighestIndexKey(bestNextState[:thoughtHistory], "action")
actionname = bestNextState[:thoughtHistory][latestActionKey][:name]
actioninput = bestNextState[:thoughtHistory][latestActionKey][:input]
# transition
newstate = transition(a, bestNextState)
a.plan[:currenttrajectory] = newstate
if actionname == "chatbox"
# add assistant msg to a.chathistory
addNewMessage(a, "assistant", actioninput)
return actioninput
elseif actionname == "recommendbox"
# add assistant msg to a.chathistory
addNewMessage(a, "assistant", actioninput)
return actioninput
else
_, a.plan[:currenttrajectory] = transition(a, a.plan[:activeplan])
end
end
end
@@ -797,6 +812,62 @@ end
# function conversation(a::T, userinput::Dict) where {T<:agent}
# # get new user msg from a.receiveUserMsgChannel
# # "newtopic" command to delete chat history
# if userinput[:text] == "newtopic"
# clearhistory(a)
# return "Okay. What shall we talk about?"
# else
# # add usermsg to a.chathistory
# addNewMessage(a, "user", userinput[:text])
# currentstate =
# if isempty(a.plan[:currenttrajectory])
# # set up initial state
# Dict{Symbol, Any}(
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
# :userselect=> nothing,
# :reward=> 0,
# :isterminal=> false,
# :evaluation=> nothing,
# :lesson=> nothing,
# :thoughtDict=> nothing,
# :totalTrajectoryReward=> nothing,
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# # :recap=>,
# :question=> userinput[:text],
# )
# )
# else
# a.plan[:currenttrajectory]
# end
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
# # transition
# newstate = transition(a, bestNextState)
# a.plan[:currenttrajectory] = newstate
# end
# end

View File

@@ -22,7 +22,7 @@ julia>
# TODO
- [] update docstring
- [PENDING] implement the function
- [WORKING] implement the function
# Signature
"""
@@ -293,7 +293,7 @@ function jsoncorrection(a::T1, input::T2,
correctjson = incorrectjson
return correctjson
catch e
@warn "Attempting correct JSON string. Attempt $attempt"
@warn "Attempting to correct JSON string. Attempt $attempt"
e = """$e"""
if occursin("EOF", e)
e = split(e, "EOF")[1] * "EOF"

View File

@@ -6,7 +6,7 @@
module mcts
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
userChatbox
userChatbox, makeNewState
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils
@@ -144,10 +144,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict)
node.state[:thoughtDict] = thoughtDict
newNodeKey, newstate = MCTStransition(a, node.state)
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict)
# add evaluator
stateevaluation, progressvalue = evaluator(a, newstate)
if newstate[:reward] < 0
@@ -294,10 +293,9 @@ julia> thoughtDict = Dict(
# Signature
"""
function MCTStransition(a::T1, state::T2
function MCTStransition(a::T1, state::T2, thoughtDict::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
thoughtDict = state[:thoughtDict]
actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input]
@@ -313,25 +311,7 @@ function MCTStransition(a::T1, state::T2
error("undefined LLM function. Requesting $actionname")
end
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
"thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("thought_$nextIndice")
latestActionKey = Symbol("action_$nextIndice")
# add Thought, action, observation to thoughtHistory
newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
newObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
return makeNewState(state, thoughtDict, response, select, reward, isterminal)
end
@@ -374,7 +354,7 @@ julia> thoughtDict = Dict(
# Signature
"""
function transition(a::T1, state::T2
function transition(a::T1, state::T2, thoughtDict::T2
)::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
thoughtDict = state[:thoughtDict]
@@ -383,36 +363,74 @@ function transition(a::T1, state::T2
# map action and input() to llm function
response, select, reward, isterminal =
if actionname == "chatbox"
userChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock"
if actionname == "winestock"
winestock(a, actioninput)
elseif actionname == "recommendbox"
userRecommendbox(a, actioninput)
else
error("undefined LLM function. Requesting $actionname")
end
latestThoughtKey, latestThoughtIndice =
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("thought_$nextIndice")
latestActionKey = Symbol("action_$nextIndice")
return makeNewState(state, thoughtDict, response, select, reward, isterminal)
end
""" Build a new state by appending one thought / action / observation step to a deep copy
    of `currentstate`.

# Arguments
    - `currentstate::AbstractDict`
        state containing a `:thoughtHistory` collection. It is deep-copied; the original
        is not modified by this function.
    - `thoughtDict::AbstractDict`
        source of the thought and action to append. They are read from `:thought` and
        `:action`, or from `:thought_N` and `:action_N` when
        `GeneralUtils.findHighestIndexKey` reports an index `N` other than -1.
    - `response::AbstractString`
        stored in the thought history as the new observation
    - `select::Union{Number, Nothing}`
        stored as `newstate[:select]`
    - `reward::Number`
        stored as `newstate[:reward]`
    - `isterminal::Bool`
        stored as `newstate[:isterminal]`

# Return
    - `(newNodeKey, newstate)`
        `newNodeKey` comes from `GeneralUtils.uuid4snakecase()`; `newstate` is the updated
        copy of `currentstate`.

# Example
```jldoctest
julia>
```

# TODO
    - [] add a runnable example
    - [TESTING] implement the function

# Signature
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
    reward::T3, isterminal::Bool
    )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}

    # Work out which step number the new thought/action/observation will occupy.
    # NOTE(review): relies on findHighestIndexKey returning :NA as the key when no
    # "thought" entry exists yet -- confirm against GeneralUtils.
    latestKey, latestIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    nextIndice = latestKey == :NA ? 1 : latestIndice + 1

    newThoughtKey = Symbol("thought_$nextIndice")
    newActionKey = Symbol("action_$nextIndice")
    newObservationKey = Symbol("observation_$(nextIndice)")

    # thoughtDict may carry plain (:thought, :action) keys or indexed ones
    # (:thought_N, :action_N); an index of -1 is treated as "no indexed key present".
    _, sourceIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
    sourceThoughtKey, sourceActionKey =
        if sourceIndice == -1
            (:thought, :action)
        else
            (Symbol("thought_$sourceIndice"), Symbol("action_$sourceIndice"))
        end

    # add Thought, action, observation to thoughtHistory of a copy so the caller's
    # state is left untouched
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][newThoughtKey] = thoughtDict[sourceThoughtKey]
    newstate[:thoughtHistory][newActionKey] = thoughtDict[sourceActionKey]
    newstate[:thoughtHistory][newObservationKey] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    newNodeKey = GeneralUtils.uuid4snakecase()
    return (newNodeKey, newstate)
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments

View File

@@ -111,8 +111,8 @@ julia> agent = YiemAgent.bsommelier(
# each plan is in [historyPoint_1, historyPoint_2, ...] format
:existingplan => Vector(),
:activeplan => Vector{Dict{Symbol, Any}}(), # current using plan
:currenttrajectory=> Vector{Dict{Symbol, Any}}(), # store
:activeplan => Dict{Symbol, Any}(), # current using plan
:currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
)
# put incoming message here. waiting for further processing

View File

@@ -59,10 +59,21 @@ tools=Dict( # update input format
tools=tools,
)
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",) )
# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
println("---> YiemAgent: ", response)
response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
println("---> YiemAgent: ", response)

View File

@@ -1,6 +1,8 @@
using Revise
using YiemAgent, GeneralUtils, JSON3, DataStructures
# ---------------------------------------------- 100 --------------------------------------------- #
msgMeta = Dict(:requestResponse => nothing,
:msgPurpose => nothing,
:receiverId => nothing,