This commit is contained in:
narawat lamaiin
2024-04-24 08:02:01 +07:00
parent 1962035990
commit 445bb11592
3 changed files with 151 additions and 14 deletions

View File

@@ -110,9 +110,9 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
else #[WORKING] new thinking else #[WORKING] new thinking
initialState = Dict( initialState = Dict(
:info=> Dict(), # keyword info :info=> Dict(), # keyword info
:thought=> nothing, :thoughtHistory=> Dict( # contain question, thought_1, action_1, observation_1, thought_2, ...
:action=> nothing, :question=> userinput[:text],
:observation=> nothing, )
) )
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0) 3, 10, 1000, 1.0)

View File

@@ -17,10 +17,23 @@ using GeneralUtils
Arguments\n Arguments\n
----- -----
state::T state::T
Represent a state of a game. Can be a Dict or something else. a state of a game. Can be a Dict or something else.
For example:
state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
visits::Integer visits::Integer
number of time the game visits this state number of time the game visits this state
stateValue::Float64 stateValue::Float64
state value
Return\n Return\n
----- -----
@@ -46,7 +59,7 @@ struct MCTSNode{T}
children::Dict{T, MCTSNode} children::Dict{T, MCTSNode}
end end
""" """ Select a node based on UCT score
Arguments\n Arguments\n
----- -----
@@ -87,10 +100,16 @@ function select(node::MCTSNode, w::Float64)
return selectedNode return selectedNode
end end
""" """ Expand selected node
Arguments\n Arguments\n
----- -----
node::MCTSNode
MCTS node
state::T
a state of a game. Can be a Dict or something else.
decisionMaker::Function
Return\n Return\n
----- -----
@@ -114,6 +133,7 @@ function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEst
# sampling action from decisionMaker # sampling action from decisionMaker
for sample in 1:n for sample in 1:n
result = decisionMaker(state)
newState = transition(node.state, action) #[] Implement your transition function newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children) if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}()) node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
@@ -262,7 +282,7 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
----- -----
[] update docstring [] update docstring
[] implement the function [] implement the function
[] implement RAG to pull similar experience [WORKING] implement RAG to pull similar experience
Signature\n Signature\n
----- -----
@@ -350,6 +370,32 @@ function isTerminal()
end end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function executeLLMFunction()
end
# ------------------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions # # Create a complete example using the defined MCTS functions #

View File

@@ -1,6 +1,6 @@
module util module util
export clearhistory, addNewMessage export clearhistory, addNewMessage, formatLLMtext, formatLLMtext_llama3instruct
using UUIDs, Dates, DataStructures, HTTP, MQTTClient, JSON3 using UUIDs, Dates, DataStructures, HTTP, MQTTClient, JSON3
using GeneralUtils using GeneralUtils
@@ -101,22 +101,114 @@ end
Signature\n Signature\n
----- -----
""" """
function addNewMessage(a::T1, role::String, text::T2; function addNewMessage(a::T1, name::String, text::T2;
maximumMsg::Integer=20) where {T1<:agent, T2<:AbstractString} maximumMsg::Integer=20) where {T1<:agent, T2<:AbstractString}
if role ["system", "user", "assistant"] # guard against typo if name ["system", "user", "assistant"] # guard against typo
error("role is not in agent.availableRole $(@__LINE__)") error("name is not in agent.availableRole $(@__LINE__)")
end end
#[] summarize the oldest 10 message #[] summarize the oldest 10 message
if length(a.chathistory) > maximumMsg if length(a.chathistory) > maximumMsg
summarize(a.chathistory) summarize(a.chathistory)
else else
d = Dict(:role=> role, :text=> text, :timestamp=> Dates.now()) d = Dict(:name=> name, :text=> text, :timestamp=> Dates.now())
push!(a.chathistory, d) push!(a.chathistory, d)
end end
end end
""" Convert a chat dictionary into LLM model instruct format.
Arguments\n
-----
name::T
message owner name e.f. "system", "user" or "assistant"
text::T
Return\n
-----
formattedtext::String
text formatted to model format
Example\n
-----
```jldoctest
julia> using Revise
julia> d = Dict(:name=> "system",:text=> "You are a helpful, respectful and honest assistant.",)
julia> formattedtext = formatLLMtext_llama3instruct(d[:name], d[:text])
```
Signature\n
-----
"""
function formatLLMtext_llama3instruct(name::T, text::T) where {T<:AbstractString}
formattedtext =
if name == "system"
"""<|begin_of_text|>
<|start_header_id|>$name<|end_header_id|>
$text
<|eot_id|>
"""
else
"""
<|start_header_id|>$name<|end_header_id|>
$text
<|eot_id|>
"""
end
return formattedtext
end
""" Convert a chat messages in vector of dictionary into LLM model instruct format.
Arguments\n
-----
messages::Vector{Dict{Symbol, T}}
message owner name e.f. "system", "user" or "assistant"
formatname::T
format name to be used
Return\n
-----
formattedtext::String
text formatted to model format
Example\n
-----
```jldoctest
julia> using Revise
julia> chatmessage = [
Dict(:name=> "system",:text=> "You are a helpful, respectful and honest assistant.",),
Dict(:name=> "user",:text=> "list me all planets in our solar system.",),
]
julia> formattedtext = formatLLMtext(chatmessage, "llama3instruct")
```
Signature\n
-----
"""
function formatLLMtext(messages::Vector{Dict{Symbol, T}},
formatname::String="llama3instruct") where {T<:Any}
f = if formatname == "llama3instruct"
formatLLMtext_llama3instruct
elseif formatname == "mistral"
# not define yet
else
error("$formatname template not define yet")
end
str = ""
for t in messages
str *= f(t[:name], t[:text])
end
return str
end
@@ -168,7 +260,6 @@ end
end # module util end # module util