diff --git a/src/interface.jl b/src/interface.jl index 96cc99e..a1afdd5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -78,8 +78,8 @@ function removeLatestMsg(a::agent) end end -function generatePrompt(a::agent; - userToken::String=" [/INST]", assistantToken=" [INST]", +function generatePrompt_tokenSuffix(a::agent; + userToken::String="[/INST]", assistantToken="[INST]", systemToken="[INST]<> content <>") prompt = nothing for msg in a.messages @@ -100,6 +100,28 @@ function generatePrompt(a::agent; return prompt end +function generatePrompt_tokenPrefix(a::agent; + userToken::String=" [/INST]", assistantToken=" [INST]", + systemToken="[INST]<> content <>") + prompt = nothing + for msg in a.messages + role = msg[:role] + content = msg[:content] + + if role == "system" + prompt = replace(systemToken, "content" => content) + elseif role == "user" + prompt *= userToken * content * " " + elseif role == "assistant" + prompt *= assistantToken * content * " " + else + error("undefied condition role = $role") + end + end + + return prompt +end + @@ -154,7 +176,6 @@ end - end # module \ No newline at end of file