This commit is contained in:
narawat lamaiin
2024-06-29 20:01:42 +07:00
parent 69c42beeca
commit 0a7c2e15bf

View File

@@ -25,7 +25,7 @@ julia>
# Signature
"""
function jsoncorrection(config::Dict, jsonstring::String, context)
function jsoncorrection(config::Dict, jsonstring::String, context; maxattempt=3)
initialstate = Dict{Symbol, Any}(
:reward=> 0,
@@ -41,18 +41,15 @@ function jsoncorrection(config::Dict, jsonstring::String, context)
transitionargs = (
config=config,
decisionMaker=decisionMaker,
evaluator=evaluator,
reflector=reflector,
context=context,
)
_, result = LLMMCTS.runMCTS(initialstate, transition, transitionargs;
totalsample=1, maxdepth=3, maxiterations=1, explorationweight=1.0)
totalsample=1, maxdepth=(maxattempt-1), maxiterations=1, explorationweight=1.0)
if result[:response] !== nothing
return response=result[:code]
if result[:isterminal] == true
return (response=result[:code], errormsg=nothing, success=true)
else
return response=result[:errorexplain]
return (response=nothing, errormsg=result[:errorexplain], success=false)
end
end
@@ -81,9 +78,6 @@ function transition(state::T1, args::NamedTuple
) where {T1<:AbstractDict}
config::AbstractDict = args[:config]
decisionMaker::Function = args[:decisionMaker]
evaluator::Function = args[:evaluator]
reflector::Function = args[:reflector]
context = args[:context]
# only for the 1st transition
@@ -107,6 +101,8 @@ function transition(state::T1, args::NamedTuple
end
end
println("Attempting to correct JSON. ", @__FILE__, " ", @__LINE__)
explain, jsonstr = decisionMaker(state, config, context)
response, errormsg, reward, isterminal = executeJSON(jsonstr)
@@ -120,7 +116,8 @@ function transition(state::T1, args::NamedTuple
newstate[:reward] = reward
newstate[:isterminal] = isterminal
stateevaluation, progressvalue = evaluator(newstate, config)
# stateevaluation, progressvalue = evaluator(newstate, config)
stateevaluation, progressvalue = ("None", 0) # vscode debug error @ evaluator(newstate, config)
return (newNodeKey=newNodeKey, newstate=newstate, progressvalue=progressvalue)
end
@@ -157,12 +154,16 @@ function decisionMaker(state, config, context)
Execution error: Code from the last round error
You should then respond to the user with:
- Why: Why couldn't the JSON string be loaded? Are there any steps missing in your plan? What does the execution error imply?
- Why: Are there any steps missing in your plan? What does the execution error imply?
- Plan:
1) What to do next to complete the task from the current situation.
2) Be specific.
- Code:
1) Write the correct JSON string version.
1) Write a correct JSON string version of the code by improving upon the last round.
You should only respond in format as described below and nothing more:
"Why": ...,
"Plan": ...,
"Code": ...
Let's begin!
@@ -196,7 +197,6 @@ function decisionMaker(state, config, context)
prompt *=
"""
<|start_header_id|>assistant<|end_header_id|>
{
"""
externalService = config[:externalservice][:text2textinstruct]
@@ -220,6 +220,7 @@ function decisionMaker(state, config, context)
:kwargs=> Dict(
:max_tokens=> 512,
:stop=> ["<|eot_id|>"],
:temperature=>0.2,
)
)
)
@@ -227,22 +228,11 @@ function decisionMaker(state, config, context)
for i in 1:5
response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
_responseJsonStr = response[:response][:text]
if length(split(_responseJsonStr, "Code:")) == 2
x = split(_responseJsonStr, "Code:")
if occursin("Why", x[2])
error("88")
end
return x
elseif length(split(_responseJsonStr, "Code\":")) == 2
x = split(_responseJsonStr, "Code\":")
if occursin("Why", x[2])
error("88")
end
return x
else
println("Trying to get Code part")
end
d = split(_responseJsonStr, "Code")
_code = d[2]
index = findfirst(":", _code)[end]
code = _code[index+1:end]
return (explain=d[1], code=code)
end
error("Failed to get Code part.")
end
@@ -268,9 +258,7 @@ julia>
# Signature
"""
function evaluator(newstate, config)
evaluation="None"
score=0
return (evaluation=evaluation, score=score)
return ("None", 0)
end
@@ -320,15 +308,15 @@ julia>
"""
function executeJSON(jsonstring)
try
jsonobj = JSON3.read(jsonstring)
d = copy(JSON3.read(jsonstring))
if typeof(jsonobj) <: JSON3.Object
if length(JSON3.write(jsonobj)) < 0.7 * length(jsonstring)
if typeof(d) <: AbstractDict
if length(JSON3.write(d)) < 0.7 * length(jsonstring)
# sometime the original jsonstr is long but the resulting response is short which means
# a lot of string was lost. This is not OK.
return (response=nothing, errormsg="JSON string lost too much content compared to the original when executed. Probably due to some brackets are missing or missplaced.", reward=0, isterminal=false)
else
return (response=copy(jsonobj), errormsg=nothing, reward=1, isterminal=true)
return (response=jsonstring, errormsg=nothing, reward=1, isterminal=true)
end
else
return (response=nothing, errormsg="JSON string parsing failed. Probably due to some brackets are missing or missplaced.", reward=0, isterminal=false)