update
This commit is contained in:
131
src/interface.jl
131
src/interface.jl
@@ -39,7 +39,8 @@ macro executeStringFunction(functionStr, args...)
|
||||
func_expr = Meta.parse(functionStr)
|
||||
|
||||
# Create a new function with the parsed expression
|
||||
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...), func_expr.args[2:end]...))
|
||||
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...),
|
||||
func_expr.args[2:end]...))
|
||||
|
||||
# Call the newly created function with the provided arguments
|
||||
function_to_call(args...)
|
||||
@@ -744,47 +745,61 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
|
||||
# Signature
|
||||
"""
|
||||
function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
|
||||
# "newtopic" command to delete chat history
|
||||
if userinput[:text] == "newtopic"
|
||||
clearhistory(a)
|
||||
|
||||
return "Okay. What shall we talk about?"
|
||||
|
||||
else
|
||||
# add usermsg to a.chathistory
|
||||
addNewMessage(a, "user", userinput[:text])
|
||||
|
||||
currentstate =
|
||||
if isempty(a.plan[:currenttrajectory])
|
||||
# set up initial state
|
||||
Dict{Symbol, Any}(
|
||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||
:userselect=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
:evaluation=> nothing,
|
||||
:lesson=> nothing,
|
||||
:thoughtDict=> nothing,
|
||||
:totalTrajectoryReward=> nothing,
|
||||
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||
# :recap=>,
|
||||
:question=> userinput[:text],
|
||||
)
|
||||
)
|
||||
a.plan[:currenttrajectory] = Dict{Symbol, Any}(
|
||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||
:userselect=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
:evaluation=> nothing,
|
||||
:lesson=> nothing,
|
||||
|
||||
:totalTrajectoryReward=> nothing,
|
||||
|
||||
# contains question, thought_1, action_1, observation_1, thought_2, ...
|
||||
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||
#[] :recap=>,
|
||||
:question=> userinput[:text],
|
||||
)
|
||||
)
|
||||
else
|
||||
a.plan[:currenttrajectory]
|
||||
_, a.plan[:currenttrajectory] = makeNewState(a.plan[:currenttrajectory],
|
||||
a.plan[:activeplan][:thoughtHistory], userinput[:text], userinput[:select],
|
||||
userinput[:reward], userinput[:isterminal])
|
||||
end
|
||||
end
|
||||
|
||||
while true
|
||||
bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker,
|
||||
evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||
a.plan[:activeplan] = bestNextState
|
||||
|
||||
latestActionKey, latestActionIndice =
|
||||
GeneralUtils.findHighestIndexKey(bestNextState[:thoughtHistory], "action")
|
||||
actionname = bestNextState[:thoughtHistory][latestActionKey][:name]
|
||||
actioninput = bestNextState[:thoughtHistory][latestActionKey][:input]
|
||||
|
||||
bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
||||
totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||
|
||||
# transition
|
||||
newstate = transition(a, bestNextState)
|
||||
a.plan[:currenttrajectory] = newstate
|
||||
|
||||
if actionname == "chatbox"
|
||||
# add assistant msg to a.chathistory
|
||||
addNewMessage(a, "assistant", actioninput)
|
||||
return actioninput
|
||||
elseif actionname == "recommendbox"
|
||||
# add assistant msg to a.chathistory
|
||||
addNewMessage(a, "assistant", actioninput)
|
||||
return actioninput
|
||||
else
|
||||
_, a.plan[:currenttrajectory] = transition(a, a.plan[:activeplan])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -797,6 +812,62 @@ end
|
||||
|
||||
|
||||
|
||||
# function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
|
||||
# # get new user msg from a.receiveUserMsgChannel
|
||||
|
||||
|
||||
# # "newtopic" command to delete chat history
|
||||
# if userinput[:text] == "newtopic"
|
||||
# clearhistory(a)
|
||||
|
||||
# return "Okay. What shall we talk about?"
|
||||
|
||||
# else
|
||||
# # add usermsg to a.chathistory
|
||||
# addNewMessage(a, "user", userinput[:text])
|
||||
|
||||
# currentstate =
|
||||
# if isempty(a.plan[:currenttrajectory])
|
||||
# # set up initial state
|
||||
# Dict{Symbol, Any}(
|
||||
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||
# :userselect=> nothing,
|
||||
# :reward=> 0,
|
||||
# :isterminal=> false,
|
||||
# :evaluation=> nothing,
|
||||
# :lesson=> nothing,
|
||||
# :thoughtDict=> nothing,
|
||||
# :totalTrajectoryReward=> nothing,
|
||||
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||
# # :recap=>,
|
||||
# :question=> userinput[:text],
|
||||
# )
|
||||
# )
|
||||
# else
|
||||
# a.plan[:currenttrajectory]
|
||||
# end
|
||||
|
||||
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
||||
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||
|
||||
# # transition
|
||||
# newstate = transition(a, bestNextState)
|
||||
# a.plan[:currenttrajectory] = newstate
|
||||
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ julia>
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [PENDING] implement the function
|
||||
- [WORKING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -293,7 +293,7 @@ function jsoncorrection(a::T1, input::T2,
|
||||
correctjson = incorrectjson
|
||||
return correctjson
|
||||
catch e
|
||||
@warn "Attempting correct JSON string. Attempt $attempt"
|
||||
@warn "Attempting to correct JSON string. Attempt $attempt"
|
||||
e = """$e"""
|
||||
if occursin("EOF", e)
|
||||
e = split(e, "EOF")[1] * "EOF"
|
||||
|
||||
100
src/mcts.jl
100
src/mcts.jl
@@ -6,7 +6,7 @@
|
||||
module mcts
|
||||
|
||||
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
|
||||
userChatbox
|
||||
userChatbox, makeNewState
|
||||
|
||||
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
|
||||
using GeneralUtils
|
||||
@@ -144,10 +144,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
println("---> expand() sample $nthSample")
|
||||
pprintln(node.state[:thoughtHistory])
|
||||
pprintln(thoughtDict)
|
||||
node.state[:thoughtDict] = thoughtDict
|
||||
newNodeKey, newstate = MCTStransition(a, node.state)
|
||||
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add evaluator
|
||||
|
||||
stateevaluation, progressvalue = evaluator(a, newstate)
|
||||
|
||||
if newstate[:reward] < 0
|
||||
@@ -294,10 +293,9 @@ julia> thoughtDict = Dict(
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function MCTStransition(a::T1, state::T2
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T2
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
|
||||
|
||||
thoughtDict = state[:thoughtDict]
|
||||
actionname = thoughtDict[:action][:name]
|
||||
actioninput = thoughtDict[:action][:input]
|
||||
|
||||
@@ -313,25 +311,7 @@ function MCTStransition(a::T1, state::T2
|
||||
error("undefined LLM function. Requesting $actionname")
|
||||
end
|
||||
|
||||
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
|
||||
"thought")
|
||||
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
||||
latestThoughtKey = Symbol("thought_$nextIndice")
|
||||
latestActionKey = Symbol("action_$nextIndice")
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
newstate = deepcopy(state)
|
||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
||||
newObservationKey = Symbol("observation_$(nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
return makeNewState(state, thoughtDict, response, select, reward, isterminal)
|
||||
end
|
||||
|
||||
|
||||
@@ -374,7 +354,7 @@ julia> thoughtDict = Dict(
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function transition(a::T1, state::T2
|
||||
function transition(a::T1, state::T2, thoughtDict::T2
|
||||
)::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
|
||||
|
||||
thoughtDict = state[:thoughtDict]
|
||||
@@ -383,36 +363,74 @@ function transition(a::T1, state::T2
|
||||
|
||||
# map action and input() to llm function
|
||||
response, select, reward, isterminal =
|
||||
if actionname == "chatbox"
|
||||
userChatbox(a, actioninput) # virtual customer
|
||||
elseif actionname == "winestock"
|
||||
if actionname == "winestock"
|
||||
winestock(a, actioninput)
|
||||
elseif actionname == "recommendbox"
|
||||
userRecommendbox(a, actioninput)
|
||||
else
|
||||
error("undefined LLM function. Requesting $actionname")
|
||||
end
|
||||
|
||||
latestThoughtKey, latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought")
|
||||
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
||||
latestThoughtKey = Symbol("thought_$nextIndice")
|
||||
latestActionKey = Symbol("action_$nextIndice")
|
||||
return makeNewState(state, thoughtDict, response, select, reward, isterminal)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
|
||||
# Return
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [TESTING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
reward::T3, isterminal::Bool
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
|
||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
|
||||
_, thoughtDict_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
|
||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
if thoughtDict_latestThoughtIndice == -1
|
||||
(:thought, :action)
|
||||
else
|
||||
(
|
||||
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||
)
|
||||
end
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
newstate = deepcopy(state)
|
||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
||||
newObservationKey = Symbol("observation_$(nextIndice)")
|
||||
newstate = deepcopy(currentstate)
|
||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
thoughtDict[thoughtDict_latestThoughtKey]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
return newstate
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
end
|
||||
|
||||
|
||||
|
||||
""" Determine whether a node is a leaf node of a search tree.
|
||||
|
||||
# Arguments
|
||||
|
||||
@@ -111,8 +111,8 @@ julia> agent = YiemAgent.bsommelier(
|
||||
# each plan is in [historyPoint_1, historyPoint_2, ...] format
|
||||
:existingplan => Vector(),
|
||||
|
||||
:activeplan => Vector{Dict{Symbol, Any}}(), # current using plan
|
||||
:currenttrajectory=> Vector{Dict{Symbol, Any}}(), # store
|
||||
:activeplan => Dict{Symbol, Any}(), # current using plan
|
||||
:currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
|
||||
)
|
||||
|
||||
# put incoming message here. waiting for further processing
|
||||
|
||||
@@ -59,10 +59,21 @@ tools=Dict( # update input format
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",) )
|
||||
|
||||
# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
|
||||
|
||||
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
|
||||
:select=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
) )
|
||||
println("---> YiemAgent: ", response)
|
||||
|
||||
response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening",
|
||||
:select=> nothing,
|
||||
:reward=> 0,
|
||||
:isterminal=> false,
|
||||
) )
|
||||
println("---> YiemAgent: ", response)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
using Revise
|
||||
using YiemAgent, GeneralUtils, JSON3, DataStructures
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
msgMeta = Dict(:requestResponse => nothing,
|
||||
:msgPurpose => nothing,
|
||||
:receiverId => nothing,
|
||||
|
||||
Reference in New Issue
Block a user