This commit is contained in:
narawat lamaiin
2024-06-01 00:37:25 +07:00
parent 3196842296
commit 97c566a9d5
5 changed files with 36 additions and 19 deletions

View File

@@ -2,7 +2,7 @@
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "c6233f8bf690740dd830d1f0927bd3afed93b8d2"
project_hash = "d5182042dab089bafa4f01ef385efd46c01a0396"
[[deps.AliasTables]]
deps = ["PtrArrays", "Random"]
@@ -96,9 +96,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "22c595ca4146c07b16bcf9c8bea86f731f7109d2"
git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.108"
version = "0.25.109"
[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
@@ -204,6 +204,12 @@ git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.31"
[[deps.LLMMCTS]]
deps = ["JSON3"]
path = "/appfolder/app/privatejuliapkg/LLMMCTS"
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
version = "0.1.0"
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
@@ -272,9 +278,9 @@ version = "2.4.6"
[[deps.MQTTClient]]
deps = ["Distributed", "Random", "Sockets"]
git-tree-sha1 = "c58ba9d6ae121f58494fa1e5164213f5b4e3e2c7"
git-tree-sha1 = "f2597b290d4bf17b577346153cd2ddf9accb5c26"
uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
version = "0.3.0"
version = "0.3.1"
weakdeps = ["PrecompileTools"]
[deps.MQTTClient.extensions]

View File

@@ -10,6 +10,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LLMMCTS = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
MQTTClient = "985f35cc-2c3d-4943-b8c1-f0931d5f0959"
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

View File

@@ -102,6 +102,7 @@ mutable struct sommelier <: agent
# 1-historyPoint is in Dict{Symbol, Any} and compose of:
# state, statevalue, thought, action, observation
plan::Dict{Symbol, Any}
mctsWorkDict::Dict{Symbol, Any}
end
function sommelier(
@@ -149,6 +150,7 @@ function sommelier(
:activeplan => Dict{Symbol, Any}(), # current using plan
:currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ...
)
mctsWorkDict::Dict{Symbol, Any} = Dict{Symbol, Any}()
)
#[NEXTVERSION] publish to a.config[:configtopic] to get a config.
@@ -167,6 +169,7 @@ function sommelier(
chathistory,
keywordinfo,
plan,
mctsWorkDict,
)
return newAgent

View File

@@ -79,7 +79,7 @@ julia> output_thoughtDict = Dict(
# Signature
"""
function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2<:AbstractDict}
function decisionMaker(config::T1, state::T2)::Dict{Symbol, Any} where {T1<:AbstractDict, T2<:AbstractDict}
customerinfo =
"""
I will give you the following information about customer:
@@ -169,7 +169,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
"""
# apply LLM specific instruct format
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
llminfo = externalService[:llminfo]
prompt =
if llminfo[:name] == "llama3instruct"
@@ -181,10 +181,10 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "decisionMaker",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
@@ -212,7 +212,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
"observation": "..."
}
"""
responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample)
responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample)
thoughtDict = copy(JSON3.read(responseJsonStr))
# check if dict has all required value
@@ -786,6 +786,11 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict}
end
function transition()
error("--> transition")
end
# """ Determine whether the state is a terminal state
# # Arguments
@@ -882,6 +887,7 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
# Signature
"""
function conversation(a::T, userinput::Dict) where {T<:agent}
config = deepcopy(a.config)
if userinput[:text] == "newtopic"
clearhistory(a)
return "Okay. What shall we talk about?"
@@ -920,8 +926,9 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
end
while true
bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker,
evaluator, reflector, totalsample=2, maxDepth=3, maxiterations=3, explorationweight=1.0)
bestNextState, besttrajectory = LLMMCTS.runMCTS(config, a.plan[:currenttrajectory],
decisionMaker, evaluator, reflector, transition;
totalsample=2, maxDepth=3, maxiterations=3, explorationweight=1.0)
a.plan[:activeplan] = bestNextState
latestActionKey, latestActionIndice =

View File

@@ -728,9 +728,9 @@ julia>
# Signature
"""
function jsoncorrection(a::T1, input::T2, correctJsonExample::T3;
function jsoncorrection(config::T1, input::T2, correctJsonExample::T3;
maxattempt::Integer=3
) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
) where {T1<:AbstractDict, T2<:AbstractString, T3<:AbstractString}
incorrectjson = deepcopy(input)
correctjson = nothing
@@ -762,7 +762,7 @@ function jsoncorrection(a::T1, input::T2, correctJsonExample::T3;
"""
# apply LLM specific instruct format
externalService = a.config[:externalservice][:text2textinstruct]
externalService = config[:externalservice][:text2textinstruct]
llminfo = externalService[:llminfo]
prompt =
if llminfo[:name] == "llama3instruct"
@@ -775,10 +775,10 @@ function jsoncorrection(a::T1, input::T2, correctJsonExample::T3;
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "jsoncorrection",
senderId= a.id,
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port],
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(