This commit is contained in:
narawat lamaiin
2024-06-24 17:46:31 +07:00
commit 7dab20df41
7 changed files with 1368 additions and 0 deletions

17
src/FormatCorrector.jl Normal file
View File

@@ -0,0 +1,17 @@
module FormatCorrector

# export

# Include order follows file dependencies: the first included file must not depend on any
# other file, and each later file may only depend on files included before it.
#
# This note is deliberately a `#` comment. A string literal placed directly in front of
# `include("interface.jl")` is parsed as a docstring for that call expression, which the
# documentation system cannot attach to anything.
include("interface.jl")
using .interface

# ---------------------------------------------- 100 --------------------------------------------- #

end # module FormatCorrector

403
src/interface.jl Normal file
View File

@@ -0,0 +1,403 @@
module interface
export jsoncorrection
using JSON3, DataStructures, Random, Dates, UUIDs, MQTTClient
using LLMMCTS, GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [] implement the function
# Signature
"""
function jsoncorrection(config::Dict, jsonstring::String, context::Dict)
initialstate = Dict{Symbol, Any}(
:reward=> 0,
:isterminal=> false,
:evaluation=> nothing,
:errormsg=> nothing,
:errorexplain=> nothing,
:question=> jsonstring,
:code=> nothing,
:response=> nothing,
)
transitionargs = (
config=config,
decisionMaker=decisionMaker,
evaluator=evaluator,
reflector=reflector,
context=context,
)
result, _ = LLMMCTS.runMCTS(initialstate, transition, transitionargs;
totalsample=1, maxdepth=3, maxiterations=1, explorationweight=1.0)
if result[:response] !== nothing
return (response=result[:response], select=nothing, reward=0, isterminal=false)
else
return (response=result[:errorexplain], select=nothing, reward=0, isterminal=false)
end
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function. This function will hace the same structure as evaluation()
and it act as iterative code generation function
# Signature
"""
function transition(state::T1, args::NamedTuple
) where {T1<:AbstractDict}
config::AbstractDict = args[:config]
decisionMaker::Function = args[:decisionMaker]
evaluator::Function = args[:evaluator]
reflector::Function = args[:reflector]
context::Union{AbstractDict, Nothing} = args[:context]
# only for the 1st transition
if state[:errormsg] === nothing
response, errormsg, reward, isterminal = executeJSON(state[:question])
state[:response] = response
state[:errormsg] = errormsg
end
explain, jsonstr = decisionMaker(state, config, context)
response, errormsg, reward, isterminal = executeJSON(jsonstr)
# make new state
newNodeKey = GeneralUtils.uuid4snakecase()
newstate = deepcopy(state)
newstate[:code] = jsonstr
newstate[:response] = response
newstate[:errormsg] = errormsg
newstate[:errorexplain] = explain
newstate[:reward] = reward
newstate[:isterminal] = isterminal
stateevaluation, progressvalue = evaluator(newstate, config)
return (newNodeKey=newNodeKey, newstate=newstate, progressvalue=progressvalue)
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function decisionMaker(state, config, context)
systemmsg =
"""
You are a helpful assistant that corrects the user's incorrect JSON string while preserve its content.
At each round of conversation, the user will give you:
Context: ....
Query: The user's original incorrect JSON string
Code from the last round: Your revised JSON string
Execution error: Your JSON string's parsing error
You should then respond to the user with:
- Why: Why couldn't the JSON string be loaded? Are there any steps missing in your plan? What does the execution error imply?
- Code:
1) Write the correct JSON string version.
You should only respond in format as described below and nothing more:
"Why": ...,
"Code": ...
Let's begin!
"""
usermsg =
if state[:code] !== nothing
"""
Context: $(JSON3.write(context[:expectedJsonExample]))
Query: $(state[:question])
Code from the last round: $(state[:code])
Execution error: $(state[:errormsg])
"""
else
"""
Context: $(JSON3.write(context[:expectedJsonExample]))
Query: $(state[:question])
Code from the last round: None
Execution error: $(state[:errormsg])
"""
end
chathistory =
[
Dict(:name=> "system", :text=> systemmsg),
Dict(:name=> "user", :text=> usermsg)
]
# put in model format
prompt = GeneralUtils.formatLLMtext(chathistory, "llama3instruct")
prompt *=
"""
<|start_header_id|>assistant<|end_header_id|>
{
"""
externalService = config[:externalservice][:text2textinstruct]
# apply LLM specific instruct format
externalService = config[:externalservice][:text2textinstruct]
msgMeta = GeneralUtils.generate_msgMeta(
externalService[:mqtttopic],
senderName= "reflector",
senderId= string(uuid4()),
receiverName= "text2textinstruct",
mqttBroker= config[:mqttServerInfo][:broker],
mqttBrokerPort= config[:mqttServerInfo][:port],
)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> prompt,
:kwargs=> Dict(
:max_tokens=> 512,
:stop=> ["<|eot_id|>"],
)
)
)
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
_responseJsonStr = response[:response][:text]
code =
if length(split(_responseJsonStr, "Code:")) == 2
split(_responseJsonStr, "Code:")
elseif length(split(_responseJsonStr, "Code\":")) == 2
split(_responseJsonStr, "Code\":")
else
error("failed to get Code part")
end
return code
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function
# Signature
"""
function evaluator(newstate, config)
evaluation="None"
score=0
return (evaluation=evaluation, score=score)
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [PENDING] implement the function
# Signature
"""
function reflector()
error("reflector")
end
"""
# Arguments
`v::Integer`
dummy variable
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function executeJSON(jsonstring)
try
jsonobj = JSON3.read(jsonstring)
if typeof(jsonobj) <: JSON3.Object
if length(JSON3.write(jsonobj)) < 0.7 * length(jsonstring)
# sometime the original jsonstr is long but the resulting response is short which means
# a lot of string was lost. This is not OK.
return (response=nothing, errormsg="JSON string lost too much content compared to the original when executed. Probably due to some brackets are missing or missplaced.", reward=0, isterminal=false)
else
return (response=copy(jsonobj), errormsg=nothing, reward=1, isterminal=true)
end
else
return (response=nothing, errormsg="JSON string parsing failed. Probably due to some brackets are missing or missplaced.", reward=0, isterminal=false)
end
catch e
io = IOBuffer()
showerror(io, e)
errorMsg = String(take!(io))
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
println("")
println("")
return (response=nothing, errormsg=errorMsg, reward=0, isterminal=false)
end
end
end # module interface