This commit is contained in:
narawat lamaiin
2024-07-08 17:16:02 +07:00
parent 0167a3bde7
commit 695cd6a2b9
2 changed files with 120 additions and 170 deletions

View File

@@ -8,17 +8,28 @@ using LLMMCTS, GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
""" Attempt to correct JSON string.
""" Attempt to correct JSON string. This function use LLM to correct JSON string
# Arguments
- `config::Dict`
A config containing services and mqttbroker info
- `jsonstring::AbstractString`
A JSON string that needs to be corrected.
- `example::String`
An example of how JSON should be structured.
- `commfunction::Function`
A function that handles communication to LLM.
# Keyword Arguments
- `maxattempt::Integer`
A number to limit how many attempt will be made.
# Return
- `(response, errormsg, success)::NamedTuple`
if success=`true`, response is a valid JSON string. otherwise, check errormsg.
# Example
```jldoctest
julia> config = Dict(
julia> function commfunction(prompt::String)
config = Dict(
:mqttServerInfo => Dict(
:description => "mqtt server info",
:port => 1883,
@@ -32,6 +43,36 @@ julia> config = Dict(
),
)
)
# apply LLM specific instruct format
externalService = config[:externalservice][:text2textinstruct]
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "reflector",
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
:kwargs=> Dict(
:max_tokens=> 512,
:stop=> ["<|eot_id|>"],
:temperature=>0.2,
)
)
)
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
return response
end
julia> incorrectjson = "{\n \"explain\": \"The user is asking for information about wines that can be paired with lamb. This seems like a straightforward question, but it's not clear which database tables will contain the relevant information. I'll start by listing all available tables to see if any of them have columns related to wine and food pairing.\",\n \"plan\":\n \"1) List all tables in the database using 'listalltables',\n 2) Check each table for columns related to wine and food pairing\",\n \"action\": {\"name\": \"listalltables\", \"input\": \"\"},\n \"expectation\": \"A list of all available tables in the database, which will help me identify potential tables that might contain relevant information.\",\n \"observation\": \"\"\n}"
julia> expectJSON =
"
Here is an expected JSON format:
@@ -43,15 +84,15 @@ julia> expectJSON =
observation: ...
}
"
julia> result = FormatCorrector.jsoncorrection(incorrectjson, expectJSON, commfunction)
julia> println(result)
" \n{\n \"explain\": \"The user is asking for information about wines that can be paired with lamb. This seems like a straightforward question, but it's not clear which database tables will contain the relevant information. I'll start by listing all available tables to see if any of them have columns related to wine and food pairing.\",\n \"plan\": [\n \"1) List all tables in the database using 'listalltables',\",\n \"2) Check each table for columns related to wine and food pairing\"\n ],\n \"action\": {\"name\": \"listalltables\", \"input\": \"\"},\n \"expectation\": \"A list of all available tables in the database, which will help me identify potential tables that might contain relevant information.\",\n \"observation\": \"\"\n}"
```
# TODO
- [WORKING] update docstring
- [x] implement the function
# Signature
"""
function jsoncorrection(config::Dict, jsonstring::String, example; maxattempt=3)
function jsoncorrection(jsonstring::T, example::String, commfunction::Function;
maxattempt=3)::NamedTuple where {T<:AbstractString}
initialstate = Dict{Symbol, Any}(
:reward=> 0,
@@ -66,8 +107,10 @@ function jsoncorrection(config::Dict, jsonstring::String, example; maxattempt=3)
)
transitionargs = (
config=config,
decisionMaker=decisionMaker,
evaluator=evaluator,
context=example,
commfunction=commfunction,
)
_, result = LLMMCTS.runMCTS(initialstate, transition, transitionargs;
totalsample=1, maxdepth=(maxattempt-1), maxiterations=1, explorationweight=1.0)
@@ -80,31 +123,13 @@ function jsoncorrection(config::Dict, jsonstring::String, example; maxattempt=3)
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function. This function will hace the same structure as evaluation()
and it act as iterative code generation function
# Signature
"""
function transition(state::T1, args::NamedTuple
) where {T1<:AbstractDict}
config::AbstractDict = args[:config]
decisionMakerF::Function = args[:decisionMaker]
evaluatorF::Function = args[:evaluator]
context = args[:context]
commf = args[:commfunction]
# only for the 1st transition
if state[:errormsg] === nothing
@@ -129,7 +154,7 @@ function transition(state::T1, args::NamedTuple
println("Attempting to correct JSON. ", @__FILE__, " ", @__LINE__)
explain, jsonstr = decisionMaker(state, config, context)
explain, jsonstr = decisionMakerF(state, context, commf)
response, errormsg, reward, isterminal = executeJSON(jsonstr)
# make new state
@@ -149,26 +174,7 @@ function transition(state::T1, args::NamedTuple
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function decisionMaker(state, config, context)
function decisionMaker(state, context, commfunction::Function)
systemmsg =
"""
You are a helpful assistant that corrects the user's incorrect JSON string while preserve its content.
@@ -225,36 +231,10 @@ function decisionMaker(state, config, context)
<|start_header_id|>assistant<|end_header_id|>
"""
externalService = config[:externalservice][:text2textinstruct]
# apply LLM specific instruct format
externalService = config[:externalservice][:text2textinstruct]
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "reflector",
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
:kwargs=> Dict(
:max_tokens=> 512,
:stop=> ["<|eot_id|>"],
:temperature=>0.2,
)
)
)
for i in 1:5
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
_responseJsonStr = response[:response][:text]
d = split(_responseJsonStr, "Code")
response = commfunction(prompt)
responseJsonStr = response[:response][:text]
d = split(responseJsonStr, "Code")
_code = d[2]
index = findfirst(":", _code)[end]
code = _code[index+1:end]
@@ -264,74 +244,11 @@ function decisionMaker(state, config, context)
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function
# Signature
"""
function evaluator(newstate, config)
return ("None", 0)
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function
# Signature
"""
function reflector()
error("reflector")
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function executeJSON(jsonstring)
try
d = copy(JSON3.read(jsonstring))

View File

@@ -1,4 +1,5 @@
using FormatCorrector
using FormatCorrector, GeneralUtils, UUIDs
incorrectjson = "{\n \"explain\": \"The user is asking for information about wines that can be paired with lamb. This seems like a straightforward question, but it's not clear which database tables will contain the relevant information. I'll start by listing all available tables to see if any of them have columns related to wine and food pairing.\",\n \"plan\":\n \"1) List all tables in the database using 'listalltables',\n 2) Check each table for columns related to wine and food pairing\",\n \"action\": {\"name\": \"listalltables\", \"input\": \"\"},\n \"expectation\": \"A list of all available tables in the database, which will help me identify potential tables that might contain relevant information.\",\n \"observation\": \"\"\n}"
@@ -16,6 +17,9 @@ expectJSON =
}
"""
function commfunction(prompt::String)
config = Dict(
:mqttServerInfo => Dict(
:description => "mqtt server info",
@@ -31,7 +35,36 @@ config = Dict(
)
)
result = FormatCorrector.jsoncorrection(config, incorrectjson, expectJSON)
# apply LLM specific instruct format
externalService = config[:externalservice][:text2textinstruct]
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "reflector",
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
:kwargs=> Dict(
:max_tokens=> 512,
:stop=> ["<|eot_id|>"],
:temperature=>0.2,
)
)
)
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
return response
end
result = FormatCorrector.jsoncorrection(incorrectjson, expectJSON, commfunction)
println(result)