181 lines
3.1 KiB
Julia
181 lines
3.1 KiB
Julia
module interface
|
|
|
|
|
|
export agent, addNewMessage, clearMessage, removeLatestMsg, generatePrompt
|
|
|
|
using JSON3, DataStructures
|
|
|
|
#------------------------------------------------------------------------------------------------100
|
|
|
|
|
|
|
|
"""
    agent(; availableRole, sessionId, maxUserMsg, messages)

Mutable chat-session state shared by the functions of this module.

# Keywords
- `availableRole`: roles a message may carry (checked by `addNewMessage`);
  defaults to `["system", "user", "assistant"]`.
- `sessionId::Int`: session identifier, defaults to `1` (not read by any function in
  this module).
- `maxUserMsg::Int`: user-message budget consulted by `addNewMessage`; defaults to `5`.
- `messages`: conversation history, oldest first, as a vector of
  `Dict(:role => role, :content => text)`; defaults to a single system message.

# Example
```julia
messages = [
    Dict(:role => "system",    :content => "You are a helpful assistant."),
    Dict(:role => "assistant", :content => "How may I help you"),
    Dict(:role => "user",      :content => "Hello, how are you"),
]
```
"""
@kwdef mutable struct agent
    # Left untyped on purpose: any collection supporting `in` is accepted.
    availableRole = ["system", "user", "assistant"]

    sessionId::Int = 1

    maxUserMsg::Int = 5

    # Prompt layout reference:
    # https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/discussions/3
    # @kwdef re-evaluates this default on every construction, so instances never
    # share the same history vector.
    messages = [Dict(:role => "system", :content => "You are a helpful assistant.")]
end
|
|
|
|
"""
    addNewMessage(a::agent, role::String, content::String) -> Int

Append `Dict(:role => role, :content => content)` to `a.messages`.

Return how many more `"user"` messages can be added before the next one triggers a
reset. Only `"user"` messages count against `a.maxUserMsg`; assistant and system
messages are free.

If `role == "user"` and the history already holds `a.maxUserMsg` user messages, the
conversation is first cleared with `clearMessage`, which keeps only the first message.
The new message is then appended to the fresh history, so it is never dropped.

# Throws
- `ErrorException` if `role` is not in `a.availableRole`.

# Example
```julia
julia> a = agent();

julia> addNewMessage(a, "user", "Where should I go to buy snacks")
4
```
"""
function addNewMessage(a::agent, role::String, content::String)
    if role ∉ a.availableRole # guard against typo
        error("role is not in agent.availableRole")
    end

    isuser(msg) = msg[:role] == "user"

    # Only "user" turns consume the budget.
    userMsg = count(isuser, a.messages)

    if role == "user" && userMsg >= a.maxUserMsg
        # Budget exhausted: start a fresh conversation (clearMessage keeps the first
        # message), then record the new user message there instead of discarding it.
        clearMessage(a)
        userMsg = count(isuser, a.messages)
    end

    push!(a.messages, Dict(:role => role, :content => content))
    if role == "user"
        userMsg += 1
    end

    return a.maxUserMsg - userMsg
end
|
|
|
|
|
|
"""
    clearMessage(a::agent)

Drop every message in `a.messages` except the first one (the system instruction in the
default layout). An empty or single-element history is left untouched. Returns
`nothing`.
"""
function clearMessage(a::agent)
    history = a.messages
    # Pop from the end until only the leading (system) message survives.
    while length(history) > 1
        pop!(history)
    end
    return nothing
end
|
|
|
|
"""
    removeLatestMsg(a::agent)

Pop and return the most recent entry of `a.messages`. If at most one message remains
(the first/system message), nothing is removed and `nothing` is returned.
"""
function removeLatestMsg(a::agent)
    # Never remove the first message.
    length(a.messages) <= 1 && return nothing
    return pop!(a.messages)
end
|
|
|
|
"""
    generatePrompt_tokenSuffix(a::agent; userToken="[/INST]", assistantToken="[INST]",
                               systemToken="[INST]<<SYS>> content <</SYS>>")

Render `a.messages` into one prompt string, placing each role token AFTER its
message text:

- `"system"`: (re)starts the prompt as `systemToken`, with every occurrence of the
  literal text `content` replaced by the message text. Anything rendered before it is
  discarded.
- `"user"`: appends `" " * content * userToken`.
- `"assistant"`: appends `" " * content * assistantToken`.

With the defaults, messages system `"sys"`, user `"Hi"`, assistant `"Hello"` render as
`"[INST]<<SYS>> sys <</SYS>> Hi[/INST] Hello[INST]"`.

Returns `nothing` if `a.messages` is empty, otherwise a `String`.

# Throws
- `ArgumentError` if a user/assistant message appears before any system message.
- `ErrorException` for any role other than system/user/assistant.
"""
function generatePrompt_tokenSuffix(a::agent;
        userToken::String="[/INST]", assistantToken="[INST]",
        systemToken="[INST]<<SYS>> content <</SYS>>")
    prompt = nothing
    for msg in a.messages
        role = msg[:role]
        content = msg[:content]

        if role == "system"
            prompt = replace(systemToken, "content" => content)
        elseif role == "user" || role == "assistant"
            # With no system block yet there is nothing to append to
            # (`nothing * String` would otherwise throw an opaque MethodError).
            prompt === nothing && throw(ArgumentError(
                "message with role \"$role\" appears before any \"system\" message"))
            token = role == "user" ? userToken : assistantToken
            prompt *= " " * content * token
        else
            error("undefied condition role = $role")
        end
    end

    return prompt
end
|
|
|
|
"""
    generatePrompt_tokenPrefix(a::agent; userToken=" [/INST]", assistantToken=" [INST]",
                               systemToken="[INST]<<SYS>> content <</SYS>>")

Render `a.messages` into one prompt string, placing each role token BEFORE its
message text:

- `"system"`: (re)starts the prompt as `systemToken`, with every occurrence of the
  literal text `content` replaced by the message text. Anything rendered before it is
  discarded.
- `"user"`: appends `userToken * content * " "`.
- `"assistant"`: appends `assistantToken * content * " "`.

With the defaults, messages system `"sys"`, user `"Hi"`, assistant `"Hello"` render as
`"[INST]<<SYS>> sys <</SYS>> [/INST]Hi  [INST]Hello "`.

NOTE(review): the default user token is `[/INST]`, so user text lands AFTER the closing
tag. In `generatePrompt_tokenSuffix` user text comes BEFORE `[/INST]`, so these
defaults look swapped. Confirm the intent before relying on them; they are kept
unchanged for compatibility.

Returns `nothing` if `a.messages` is empty, otherwise a `String`.

# Throws
- `ArgumentError` if a user/assistant message appears before any system message.
- `ErrorException` for any role other than system/user/assistant.
"""
function generatePrompt_tokenPrefix(a::agent;
        userToken::String=" [/INST]", assistantToken=" [INST]",
        systemToken="[INST]<<SYS>> content <</SYS>>")
    prompt = nothing
    for msg in a.messages
        role = msg[:role]
        content = msg[:content]

        if role == "system"
            prompt = replace(systemToken, "content" => content)
        elseif role == "user" || role == "assistant"
            # With no system block yet there is nothing to append to
            # (`nothing * String` would otherwise throw an opaque MethodError).
            prompt === nothing && throw(ArgumentError(
                "message with role \"$role\" appears before any \"system\" message"))
            token = role == "user" ? userToken : assistantToken
            prompt *= token * content * " "
        else
            error("undefied condition role = $role")
        end
    end

    return prompt
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end # module |