diff --git a/Manifest.toml b/Manifest.toml index 40d6d1f..4954cb4 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "c6233f8bf690740dd830d1f0927bd3afed93b8d2" +project_hash = "d5182042dab089bafa4f01ef385efd46c01a0396" [[deps.AliasTables]] deps = ["PtrArrays", "Random"] @@ -96,9 +96,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "22c595ca4146c07b16bcf9c8bea86f731f7109d2" +git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.108" +version = "0.25.109" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -204,6 +204,12 @@ git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" version = "0.9.31" +[[deps.LLMMCTS]] +deps = ["JSON3"] +path = "/appfolder/app/privatejuliapkg/LLMMCTS" +uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241" +version = "0.1.0" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -272,9 +278,9 @@ version = "2.4.6" [[deps.MQTTClient]] deps = ["Distributed", "Random", "Sockets"] -git-tree-sha1 = "c58ba9d6ae121f58494fa1e5164213f5b4e3e2c7" +git-tree-sha1 = "f2597b290d4bf17b577346153cd2ddf9accb5c26" uuid = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" -version = "0.3.0" +version = "0.3.1" weakdeps = ["PrecompileTools"] [deps.MQTTClient.extensions] diff --git a/Project.toml b/Project.toml index b3951f5..55e0cad 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +LLMMCTS = "d76c5a4d-449e-4835-8cc4-dd86ec44f241" MQTTClient = "985f35cc-2c3d-4943-b8c1-f0931d5f0959" PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" diff --git a/previousversion/0.1/src/type.jl b/previousversion/0.1/src/type.jl index f55ad1a..f6b98e7 100644 --- a/previousversion/0.1/src/type.jl +++ b/previousversion/0.1/src/type.jl @@ -102,6 +102,7 @@ mutable struct sommelier <: agent # 1-historyPoint is in Dict{Symbol, Any} and compose of: # state, statevalue, thought, action, observation plan::Dict{Symbol, Any} + mctsWorkDict::Dict{Symbol, Any} end function sommelier( @@ -149,6 +150,7 @@ function sommelier( :activeplan => Dict{Symbol, Any}(), # current using plan :currenttrajectory=> Dict{Symbol, Any}(), # store question, thought, action, observation, ... ) + mctsWorkDict::Dict{Symbol, Any} = Dict{Symbol, Any}() ) #[NEXTVERSION] publish to a.config[:configtopic] to get a config. @@ -167,6 +169,7 @@ function sommelier( chathistory, keywordinfo, plan, + mctsWorkDict, ) return newAgent diff --git a/src/interface.jl b/src/interface.jl index 242e314..0a8d62a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -79,7 +79,7 @@ julia> output_thoughtDict = Dict( # Signature """ -function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2<:AbstractDict} +function decisionMaker(config::T1, state::T2)::Dict{Symbol, Any} where {T1<:AbstractDict, T2<:AbstractDict} customerinfo = """ I will give you the following information about customer: @@ -169,7 +169,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 """ # apply LLM specific instruct format - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] llminfo = externalService[:llminfo] prompt = if llminfo[:name] == "llama3instruct" @@ -181,10 +181,10 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 msgMeta = GeneralUtils.generate_msgMeta( externalService[:mqtttopic], senderName= "decisionMaker", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], ) outgoingMsg = Dict( @@ -212,7 +212,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 "observation": "..." } """ - responseJsonStr = jsoncorrection(a, _responseJsonStr, expectedJsonExample) + responseJsonStr = jsoncorrection(config, _responseJsonStr, expectedJsonExample) thoughtDict = copy(JSON3.read(responseJsonStr)) # check if dict has all required value @@ -786,6 +786,11 @@ function reflector(a::T1, state::T2)::String where {T1<:agent, T2<:AbstractDict} end +function transition() + error("--> transition") +end + + # """ Determine whether the state is a terminal state # # Arguments @@ -882,6 +887,7 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?") # Signature """ function conversation(a::T, userinput::Dict) where {T<:agent} + config = deepcopy(a.config) if userinput[:text] == "newtopic" clearhistory(a) return "Okay. What shall we talk about?" @@ -920,8 +926,9 @@ function conversation(a::T, userinput::Dict) where {T<:agent} end while true - bestNextState, besttrajectory = runMCTS(a, a.plan[:currenttrajectory], decisionMaker, - evaluator, reflector, totalsample=2, maxDepth=3, maxiterations=3, explorationweight=1.0) + bestNextState, besttrajectory = LLMMCTS.runMCTS(config, a.plan[:currenttrajectory], + decisionMaker, evaluator, reflector, transition; + totalsample=2, maxDepth=3, maxiterations=3, explorationweight=1.0) a.plan[:activeplan] = bestNextState latestActionKey, latestActionIndice = diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 2e25ad9..c169929 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -728,9 +728,9 @@ julia> # Signature """ -function jsoncorrection(a::T1, input::T2, correctJsonExample::T3; +function jsoncorrection(config::T1, input::T2, correctJsonExample::T3; maxattempt::Integer=3 - ) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} + ) where {T1<:AbstractDict, T2<:AbstractString, T3<:AbstractString} incorrectjson = deepcopy(input) correctjson = nothing @@ -762,7 +762,7 @@ function jsoncorrection(a::T1, input::T2, correctJsonExample::T3; """ # apply LLM specific instruct format - externalService = a.config[:externalservice][:text2textinstruct] + externalService = config[:externalservice][:text2textinstruct] llminfo = externalService[:llminfo] prompt = if llminfo[:name] == "llama3instruct" @@ -775,10 +775,10 @@ function jsoncorrection(a::T1, input::T2, correctJsonExample::T3; msgMeta = GeneralUtils.generate_msgMeta( externalService[:mqtttopic], senderName= "jsoncorrection", - senderId= a.id, + senderId= string(uuid4()), receiverName= "text2textinstruct", - mqttBroker= a.config[:mqttServerInfo][:broker], - mqttBrokerPort= a.config[:mqttServerInfo][:port], + mqttBroker= config[:mqttServerInfo][:broker], + mqttBrokerPort= config[:mqttServerInfo][:port], ) outgoingMsg = Dict(