update
This commit is contained in:
105
src/interface.jl
105
src/interface.jl
@@ -39,7 +39,8 @@ macro executeStringFunction(functionStr, args...)
|
|||||||
func_expr = Meta.parse(functionStr)
|
func_expr = Meta.parse(functionStr)
|
||||||
|
|
||||||
# Create a new function with the parsed expression
|
# Create a new function with the parsed expression
|
||||||
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...), func_expr.args[2:end]...))
|
function_to_call = eval(Expr(:function, Expr(:call, func_expr, args...),
|
||||||
|
func_expr.args[2:end]...))
|
||||||
|
|
||||||
# Call the newly created function with the provided arguments
|
# Call the newly created function with the provided arguments
|
||||||
function_to_call(args...)
|
function_to_call(args...)
|
||||||
@@ -744,21 +745,15 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function conversation(a::T, userinput::Dict) where {T<:agent}
|
function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||||
|
|
||||||
# "newtopic" command to delete chat history
|
|
||||||
if userinput[:text] == "newtopic"
|
if userinput[:text] == "newtopic"
|
||||||
clearhistory(a)
|
clearhistory(a)
|
||||||
|
|
||||||
return "Okay. What shall we talk about?"
|
return "Okay. What shall we talk about?"
|
||||||
|
|
||||||
else
|
else
|
||||||
# add usermsg to a.chathistory
|
# add usermsg to a.chathistory
|
||||||
addNewMessage(a, "user", userinput[:text])
|
addNewMessage(a, "user", userinput[:text])
|
||||||
|
|
||||||
currentstate =
|
|
||||||
if isempty(a.plan[:currenttrajectory])
|
if isempty(a.plan[:currenttrajectory])
|
||||||
# set up initial state
|
a.plan[:currenttrajectory] = Dict{Symbol, Any}(
|
||||||
Dict{Symbol, Any}(
|
|
||||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||||
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||||
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||||
@@ -767,24 +762,44 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
:isterminal=> false,
|
:isterminal=> false,
|
||||||
:evaluation=> nothing,
|
:evaluation=> nothing,
|
||||||
:lesson=> nothing,
|
:lesson=> nothing,
|
||||||
:thoughtDict=> nothing,
|
|
||||||
:totalTrajectoryReward=> nothing,
|
:totalTrajectoryReward=> nothing,
|
||||||
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
|
||||||
# :recap=>,
|
# contains question, thought_1, action_1, observation_1, thought_2, ...
|
||||||
|
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||||
|
#[] :recap=>,
|
||||||
:question=> userinput[:text],
|
:question=> userinput[:text],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else
|
else
|
||||||
a.plan[:currenttrajectory]
|
_, a.plan[:currenttrajectory] = makeNewState(a.plan[:currenttrajectory],
|
||||||
|
a.plan[:activeplan][:thoughtHistory], userinput[:text], userinput[:select],
|
||||||
|
userinput[:reward], userinput[:isterminal])
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
while true
|
||||||
totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker,
|
||||||
|
evaluator, reflector, totalsample=2, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||||
|
a.plan[:activeplan] = bestNextState
|
||||||
|
|
||||||
|
latestActionKey, latestActionIndice =
|
||||||
|
GeneralUtils.findHighestIndexKey(bestNextState[:thoughtHistory], "action")
|
||||||
|
actionname = bestNextState[:thoughtHistory][latestActionKey][:name]
|
||||||
|
actioninput = bestNextState[:thoughtHistory][latestActionKey][:input]
|
||||||
|
|
||||||
# transition
|
# transition
|
||||||
newstate = transition(a, bestNextState)
|
if actionname == "chatbox"
|
||||||
a.plan[:currenttrajectory] = newstate
|
# add assistant msg to a.chathistory
|
||||||
|
addNewMessage(a, "assistant", actioninput)
|
||||||
|
return actioninput
|
||||||
|
elseif actionname == "recommendbox"
|
||||||
|
# add assistant msg to a.chathistory
|
||||||
|
addNewMessage(a, "assistant", actioninput)
|
||||||
|
return actioninput
|
||||||
|
else
|
||||||
|
_, a.plan[:currenttrajectory] = transition(a, a.plan[:activeplan])
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -797,6 +812,62 @@ end
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||||
|
|
||||||
|
# # get new user msg from a.receiveUserMsgChannel
|
||||||
|
|
||||||
|
|
||||||
|
# # "newtopic" command to delete chat history
|
||||||
|
# if userinput[:text] == "newtopic"
|
||||||
|
# clearhistory(a)
|
||||||
|
|
||||||
|
# return "Okay. What shall we talk about?"
|
||||||
|
|
||||||
|
# else
|
||||||
|
# # add usermsg to a.chathistory
|
||||||
|
# addNewMessage(a, "user", userinput[:text])
|
||||||
|
|
||||||
|
# currentstate =
|
||||||
|
# if isempty(a.plan[:currenttrajectory])
|
||||||
|
# # set up initial state
|
||||||
|
# Dict{Symbol, Any}(
|
||||||
|
# # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||||
|
# :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||||
|
# :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||||
|
# :userselect=> nothing,
|
||||||
|
# :reward=> 0,
|
||||||
|
# :isterminal=> false,
|
||||||
|
# :evaluation=> nothing,
|
||||||
|
# :lesson=> nothing,
|
||||||
|
# :thoughtDict=> nothing,
|
||||||
|
# :totalTrajectoryReward=> nothing,
|
||||||
|
# :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||||
|
# # :recap=>,
|
||||||
|
# :question=> userinput[:text],
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# else
|
||||||
|
# a.plan[:currenttrajectory]
|
||||||
|
# end
|
||||||
|
|
||||||
|
# bestNextState, besttrajectory = runMCTS(a, currentstate, decisionMaker, evaluator, reflector,
|
||||||
|
# totalsample=3, maxDepth=2, maxiterations=1, explorationweight=1.0)
|
||||||
|
|
||||||
|
# # transition
|
||||||
|
# newstate = transition(a, bestNextState)
|
||||||
|
# a.plan[:currenttrajectory] = newstate
|
||||||
|
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [PENDING] implement the function
|
- [WORKING] implement the function
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -293,7 +293,7 @@ function jsoncorrection(a::T1, input::T2,
|
|||||||
correctjson = incorrectjson
|
correctjson = incorrectjson
|
||||||
return correctjson
|
return correctjson
|
||||||
catch e
|
catch e
|
||||||
@warn "Attempting correct JSON string. Attempt $attempt"
|
@warn "Attempting to correct JSON string. Attempt $attempt"
|
||||||
e = """$e"""
|
e = """$e"""
|
||||||
if occursin("EOF", e)
|
if occursin("EOF", e)
|
||||||
e = split(e, "EOF")[1] * "EOF"
|
e = split(e, "EOF")[1] * "EOF"
|
||||||
|
|||||||
100
src/mcts.jl
100
src/mcts.jl
@@ -6,7 +6,7 @@
|
|||||||
module mcts
|
module mcts
|
||||||
|
|
||||||
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
|
export MCTSNode, runMCTS, isleaf, selectBestNextState, selectBestTrajectory, transition,
|
||||||
userChatbox
|
userChatbox, makeNewState
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
|
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -144,10 +144,9 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
println("---> expand() sample $nthSample")
|
println("---> expand() sample $nthSample")
|
||||||
pprintln(node.state[:thoughtHistory])
|
pprintln(node.state[:thoughtHistory])
|
||||||
pprintln(thoughtDict)
|
pprintln(thoughtDict)
|
||||||
node.state[:thoughtDict] = thoughtDict
|
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict)
|
||||||
newNodeKey, newstate = MCTStransition(a, node.state)
|
|
||||||
|
|
||||||
# add evaluator
|
|
||||||
stateevaluation, progressvalue = evaluator(a, newstate)
|
stateevaluation, progressvalue = evaluator(a, newstate)
|
||||||
|
|
||||||
if newstate[:reward] < 0
|
if newstate[:reward] < 0
|
||||||
@@ -294,10 +293,9 @@ julia> thoughtDict = Dict(
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function MCTStransition(a::T1, state::T2
|
function MCTStransition(a::T1, state::T2, thoughtDict::T2
|
||||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
|
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
thoughtDict = state[:thoughtDict]
|
|
||||||
actionname = thoughtDict[:action][:name]
|
actionname = thoughtDict[:action][:name]
|
||||||
actioninput = thoughtDict[:action][:input]
|
actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
@@ -313,25 +311,7 @@ function MCTStransition(a::T1, state::T2
|
|||||||
error("undefined LLM function. Requesting $actionname")
|
error("undefined LLM function. Requesting $actionname")
|
||||||
end
|
end
|
||||||
|
|
||||||
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
|
return makeNewState(state, thoughtDict, response, select, reward, isterminal)
|
||||||
"thought")
|
|
||||||
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
|
||||||
latestThoughtKey = Symbol("thought_$nextIndice")
|
|
||||||
latestActionKey = Symbol("action_$nextIndice")
|
|
||||||
|
|
||||||
# add Thought, action, observation to thoughtHistory
|
|
||||||
newstate = deepcopy(state)
|
|
||||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
|
||||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
|
||||||
newObservationKey = Symbol("observation_$(nextIndice)")
|
|
||||||
newstate[:thoughtHistory][newObservationKey] = response
|
|
||||||
newstate[:reward] = reward
|
|
||||||
newstate[:select] = select
|
|
||||||
newstate[:isterminal] = isterminal
|
|
||||||
|
|
||||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
|
||||||
|
|
||||||
return (newNodeKey, newstate)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -374,7 +354,7 @@ julia> thoughtDict = Dict(
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function transition(a::T1, state::T2
|
function transition(a::T1, state::T2, thoughtDict::T2
|
||||||
)::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
|
)::Dict{Symbol, <:Any} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
thoughtDict = state[:thoughtDict]
|
thoughtDict = state[:thoughtDict]
|
||||||
@@ -383,36 +363,74 @@ function transition(a::T1, state::T2
|
|||||||
|
|
||||||
# map action and input() to llm function
|
# map action and input() to llm function
|
||||||
response, select, reward, isterminal =
|
response, select, reward, isterminal =
|
||||||
if actionname == "chatbox"
|
if actionname == "winestock"
|
||||||
userChatbox(a, actioninput) # virtual customer
|
|
||||||
elseif actionname == "winestock"
|
|
||||||
winestock(a, actioninput)
|
winestock(a, actioninput)
|
||||||
elseif actionname == "recommendbox"
|
|
||||||
userRecommendbox(a, actioninput)
|
|
||||||
else
|
else
|
||||||
error("undefined LLM function. Requesting $actionname")
|
error("undefined LLM function. Requesting $actionname")
|
||||||
end
|
end
|
||||||
|
|
||||||
latestThoughtKey, latestThoughtIndice =
|
return makeNewState(state, thoughtDict, response, select, reward, isterminal)
|
||||||
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "thought")
|
end
|
||||||
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
|
||||||
latestThoughtKey = Symbol("thought_$nextIndice")
|
|
||||||
latestActionKey = Symbol("action_$nextIndice")
|
"""
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
|
||||||
|
# Return
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docstring
|
||||||
|
- [TESTING] implement the function
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||||
|
reward::T3, isterminal::Bool
|
||||||
|
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||||
|
|
||||||
|
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||||
|
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||||
|
currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||||
|
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||||
|
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||||
|
|
||||||
|
_, thoughtDict_latestThoughtIndice =
|
||||||
|
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||||
|
|
||||||
|
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||||
|
if thoughtDict_latestThoughtIndice == -1
|
||||||
|
(:thought, :action)
|
||||||
|
else
|
||||||
|
(
|
||||||
|
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||||
|
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
# add Thought, action, observation to thoughtHistory
|
# add Thought, action, observation to thoughtHistory
|
||||||
newstate = deepcopy(state)
|
newstate = deepcopy(currentstate)
|
||||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
thoughtDict[thoughtDict_latestThoughtKey]
|
||||||
newObservationKey = Symbol("observation_$(nextIndice)")
|
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||||
|
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||||
newstate[:thoughtHistory][newObservationKey] = response
|
newstate[:thoughtHistory][newObservationKey] = response
|
||||||
newstate[:reward] = reward
|
newstate[:reward] = reward
|
||||||
newstate[:select] = select
|
newstate[:select] = select
|
||||||
newstate[:isterminal] = isterminal
|
newstate[:isterminal] = isterminal
|
||||||
|
|
||||||
return newstate
|
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
|
|
||||||
|
return (newNodeKey, newstate)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
""" Determine whether a node is a leaf node of a search tree.
|
""" Determine whether a node is a leaf node of a search tree.
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
|||||||
@@ -111,8 +111,8 @@ julia> agent = YiemAgent.bsommelier(
|
|||||||
# each plan is in [historyPoint_1, historyPoint_2, ...] format
|
# each plan is in [historyPoint_1, historyPoint_2, ...] format
|
||||||
:existingplan => Vector(),
|
:existingplan => Vector(),
|
||||||
|
|
||||||
:activeplan => Vector{Dict{Symbol, Any}}(), # current using plan
|
:activeplan => Dict{Symbol, Any}(), # current using plan
|
||||||
:currenttrajectory=> Vector{Dict{Symbol, Any}}(), # store
|
:currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
|
||||||
)
|
)
|
||||||
|
|
||||||
# put incoming message here. waiting for further processing
|
# put incoming message here. waiting for further processing
|
||||||
|
|||||||
@@ -59,10 +59,21 @@ tools=Dict( # update input format
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",) )
|
# response = YiemAgent.conversation(a, Dict(:text=> "newtopic",) )
|
||||||
|
|
||||||
|
|
||||||
|
response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a bottle of wine",
|
||||||
|
:select=> nothing,
|
||||||
|
:reward=> 0,
|
||||||
|
:isterminal=> false,
|
||||||
|
) )
|
||||||
|
println("---> YiemAgent: ", response)
|
||||||
|
|
||||||
|
response = YiemAgent.conversation(a, Dict(:text=> "I'm having a graduation party this evening",
|
||||||
|
:select=> nothing,
|
||||||
|
:reward=> 0,
|
||||||
|
:isterminal=> false,
|
||||||
|
) )
|
||||||
|
println("---> YiemAgent: ", response)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
using Revise
|
using Revise
|
||||||
using YiemAgent, GeneralUtils, JSON3, DataStructures
|
using YiemAgent, GeneralUtils, JSON3, DataStructures
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
msgMeta = Dict(:requestResponse => nothing,
|
msgMeta = Dict(:requestResponse => nothing,
|
||||||
:msgPurpose => nothing,
|
:msgPurpose => nothing,
|
||||||
:receiverId => nothing,
|
:receiverId => nothing,
|
||||||
|
|||||||
Reference in New Issue
Block a user