update
This commit is contained in:
@@ -110,9 +110,9 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
else #[WORKING] new thinking
|
||||
initialState = Dict(
|
||||
:info=> Dict(), # keyword info
|
||||
:thought=> nothing,
|
||||
:action=> nothing,
|
||||
:observation=> nothing,
|
||||
:thoughtHistory=> Dict( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||
:question=> userinput[:text],
|
||||
)
|
||||
)
|
||||
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
||||
3, 10, 1000, 1.0)
|
||||
|
||||
54
src/mcts.jl
54
src/mcts.jl
@@ -17,10 +17,23 @@ using GeneralUtils
|
||||
Arguments\n
|
||||
-----
|
||||
state::T
|
||||
Represent a state of a game. Can be a Dict or something else.
|
||||
a state of a game. Can be a Dict or something else.
|
||||
For example:
|
||||
state = Dict(
|
||||
:info=> Dict(), # keyword info
|
||||
:thoughtHistory=> Dict(
|
||||
:question=> _,
|
||||
:thought_1=> _,
|
||||
:action_1=> _,
|
||||
:observation_1=> _,
|
||||
:thought_2=> _,
|
||||
...
|
||||
)
|
||||
)
|
||||
visits::Integer
|
||||
number of time the game visits this state
|
||||
stateValue::Float64
|
||||
state value
|
||||
|
||||
Return\n
|
||||
-----
|
||||
@@ -46,7 +59,7 @@ struct MCTSNode{T}
|
||||
children::Dict{T, MCTSNode}
|
||||
end
|
||||
|
||||
"""
|
||||
""" Select a node based on UCT score
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
@@ -87,10 +100,16 @@ function select(node::MCTSNode, w::Float64)
|
||||
return selectedNode
|
||||
end
|
||||
|
||||
"""
|
||||
""" Expand selected node
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
node::MCTSNode
|
||||
MCTS node
|
||||
state::T
|
||||
a state of a game. Can be a Dict or something else.
|
||||
decisionMaker::Function
|
||||
|
||||
|
||||
Return\n
|
||||
-----
|
||||
@@ -114,6 +133,7 @@ function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEst
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
result = decisionMaker(state)
|
||||
newState = transition(node.state, action) #[] Implement your transition function
|
||||
if newState ∉ keys(node.children)
|
||||
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
||||
@@ -262,7 +282,7 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
[] implement RAG to pull similar experience
|
||||
[WORKING] implement RAG to pull similar experience
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
@@ -350,6 +370,32 @@ function isTerminal()
|
||||
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[WORKING] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function executeLLMFunction()
|
||||
|
||||
end
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
# Create a complete example using the defined MCTS functions #
|
||||
|
||||
103
src/util.jl
103
src/util.jl
@@ -1,6 +1,6 @@
|
||||
module util
|
||||
|
||||
export clearhistory, addNewMessage
|
||||
export clearhistory, addNewMessage, formatLLMtext, formatLLMtext_llama3instruct
|
||||
|
||||
using UUIDs, Dates, DataStructures, HTTP, MQTTClient, JSON3
|
||||
using GeneralUtils
|
||||
@@ -101,22 +101,114 @@ end
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function addNewMessage(a::T1, role::String, text::T2;
|
||||
function addNewMessage(a::T1, name::String, text::T2;
|
||||
maximumMsg::Integer=20) where {T1<:agent, T2<:AbstractString}
|
||||
if role ∉ ["system", "user", "assistant"] # guard against typo
|
||||
error("role is not in agent.availableRole $(@__LINE__)")
|
||||
if name ∉ ["system", "user", "assistant"] # guard against typo
|
||||
error("name is not in agent.availableRole $(@__LINE__)")
|
||||
end
|
||||
|
||||
#[] summarize the oldest 10 message
|
||||
if length(a.chathistory) > maximumMsg
|
||||
summarize(a.chathistory)
|
||||
else
|
||||
d = Dict(:role=> role, :text=> text, :timestamp=> Dates.now())
|
||||
d = Dict(:name=> name, :text=> text, :timestamp=> Dates.now())
|
||||
push!(a.chathistory, d)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Convert a chat dictionary into LLM model instruct format.
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
name::T
|
||||
message owner name e.f. "system", "user" or "assistant"
|
||||
text::T
|
||||
|
||||
Return\n
|
||||
-----
|
||||
formattedtext::String
|
||||
text formatted to model format
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia> using Revise
|
||||
julia> d = Dict(:name=> "system",:text=> "You are a helpful, respectful and honest assistant.",)
|
||||
julia> formattedtext = formatLLMtext_llama3instruct(d[:name], d[:text])
|
||||
```
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function formatLLMtext_llama3instruct(name::T, text::T) where {T<:AbstractString}
|
||||
formattedtext =
|
||||
if name == "system"
|
||||
"""<|begin_of_text|>
|
||||
<|start_header_id|>$name<|end_header_id|>
|
||||
$text
|
||||
<|eot_id|>
|
||||
"""
|
||||
else
|
||||
"""
|
||||
<|start_header_id|>$name<|end_header_id|>
|
||||
$text
|
||||
<|eot_id|>
|
||||
"""
|
||||
end
|
||||
|
||||
return formattedtext
|
||||
end
|
||||
|
||||
|
||||
|
||||
""" Convert a chat messages in vector of dictionary into LLM model instruct format.
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
messages::Vector{Dict{Symbol, T}}
|
||||
message owner name e.f. "system", "user" or "assistant"
|
||||
formatname::T
|
||||
format name to be used
|
||||
|
||||
Return\n
|
||||
-----
|
||||
formattedtext::String
|
||||
text formatted to model format
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia> using Revise
|
||||
julia> chatmessage = [
|
||||
Dict(:name=> "system",:text=> "You are a helpful, respectful and honest assistant.",),
|
||||
Dict(:name=> "user",:text=> "list me all planets in our solar system.",),
|
||||
]
|
||||
julia> formattedtext = formatLLMtext(chatmessage, "llama3instruct")
|
||||
```
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function formatLLMtext(messages::Vector{Dict{Symbol, T}},
|
||||
formatname::String="llama3instruct") where {T<:Any}
|
||||
f = if formatname == "llama3instruct"
|
||||
formatLLMtext_llama3instruct
|
||||
elseif formatname == "mistral"
|
||||
# not define yet
|
||||
else
|
||||
error("$formatname template not define yet")
|
||||
end
|
||||
|
||||
str = ""
|
||||
for t in messages
|
||||
str *= f(t[:name], t[:text])
|
||||
end
|
||||
|
||||
return str
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -168,7 +260,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module util
|
||||
Reference in New Issue
Block a user