update
This commit is contained in:
@@ -110,9 +110,9 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
else #[WORKING] new thinking
|
else #[WORKING] new thinking
|
||||||
initialState = Dict(
|
initialState = Dict(
|
||||||
:info=> Dict(), # keyword info
|
:info=> Dict(), # keyword info
|
||||||
:thought=> nothing,
|
:thoughtHistory=> Dict( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||||
:action=> nothing,
|
:question=> userinput[:text],
|
||||||
:observation=> nothing,
|
)
|
||||||
)
|
)
|
||||||
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
||||||
3, 10, 1000, 1.0)
|
3, 10, 1000, 1.0)
|
||||||
|
|||||||
54
src/mcts.jl
54
src/mcts.jl
@@ -17,10 +17,23 @@ using GeneralUtils
|
|||||||
Arguments\n
|
Arguments\n
|
||||||
-----
|
-----
|
||||||
state::T
|
state::T
|
||||||
Represent a state of a game. Can be a Dict or something else.
|
a state of a game. Can be a Dict or something else.
|
||||||
|
For example:
|
||||||
|
state = Dict(
|
||||||
|
:info=> Dict(), # keyword info
|
||||||
|
:thoughtHistory=> Dict(
|
||||||
|
:question=> _,
|
||||||
|
:thought_1=> _,
|
||||||
|
:action_1=> _,
|
||||||
|
:observation_1=> _,
|
||||||
|
:thought_2=> _,
|
||||||
|
...
|
||||||
|
)
|
||||||
|
)
|
||||||
visits::Integer
|
visits::Integer
|
||||||
number of time the game visits this state
|
number of time the game visits this state
|
||||||
stateValue::Float64
|
stateValue::Float64
|
||||||
|
state value
|
||||||
|
|
||||||
Return\n
|
Return\n
|
||||||
-----
|
-----
|
||||||
@@ -46,7 +59,7 @@ struct MCTSNode{T}
|
|||||||
children::Dict{T, MCTSNode}
|
children::Dict{T, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
""" Select a node based on UCT score
|
||||||
|
|
||||||
Arguments\n
|
Arguments\n
|
||||||
-----
|
-----
|
||||||
@@ -87,10 +100,16 @@ function select(node::MCTSNode, w::Float64)
|
|||||||
return selectedNode
|
return selectedNode
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
""" Expand selected node
|
||||||
|
|
||||||
Arguments\n
|
Arguments\n
|
||||||
-----
|
-----
|
||||||
|
node::MCTSNode
|
||||||
|
MCTS node
|
||||||
|
state::T
|
||||||
|
a state of a game. Can be a Dict or something else.
|
||||||
|
decisionMaker::Function
|
||||||
|
|
||||||
|
|
||||||
Return\n
|
Return\n
|
||||||
-----
|
-----
|
||||||
@@ -114,6 +133,7 @@ function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEst
|
|||||||
|
|
||||||
# sampling action from decisionMaker
|
# sampling action from decisionMaker
|
||||||
for sample in 1:n
|
for sample in 1:n
|
||||||
|
result = decisionMaker(state)
|
||||||
newState = transition(node.state, action) #[] Implement your transition function
|
newState = transition(node.state, action) #[] Implement your transition function
|
||||||
if newState ∉ keys(node.children)
|
if newState ∉ keys(node.children)
|
||||||
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
||||||
@@ -262,7 +282,7 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
|||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[] implement the function
|
[] implement the function
|
||||||
[] implement RAG to pull similar experience
|
[WORKING] implement RAG to pull similar experience
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
@@ -350,6 +370,32 @@ function isTerminal()
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO\n
|
||||||
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[WORKING] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function executeLLMFunction()
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
# Create a complete example using the defined MCTS functions #
|
# Create a complete example using the defined MCTS functions #
|
||||||
|
|||||||
101
src/util.jl
101
src/util.jl
@@ -1,6 +1,6 @@
|
|||||||
module util
|
module util
|
||||||
|
|
||||||
export clearhistory, addNewMessage
|
export clearhistory, addNewMessage, formatLLMtext, formatLLMtext_llama3instruct
|
||||||
|
|
||||||
using UUIDs, Dates, DataStructures, HTTP, MQTTClient, JSON3
|
using UUIDs, Dates, DataStructures, HTTP, MQTTClient, JSON3
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
@@ -101,22 +101,113 @@ end
|
|||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function addNewMessage(a::T1, role::String, text::T2;
|
function addNewMessage(a::T1, name::String, text::T2;
|
||||||
maximumMsg::Integer=20) where {T1<:agent, T2<:AbstractString}
|
maximumMsg::Integer=20) where {T1<:agent, T2<:AbstractString}
|
||||||
if role ∉ ["system", "user", "assistant"] # guard against typo
|
if name ∉ ["system", "user", "assistant"] # guard against typo
|
||||||
error("role is not in agent.availableRole $(@__LINE__)")
|
error("name is not in agent.availableRole $(@__LINE__)")
|
||||||
end
|
end
|
||||||
|
|
||||||
#[] summarize the oldest 10 message
|
#[] summarize the oldest 10 message
|
||||||
if length(a.chathistory) > maximumMsg
|
if length(a.chathistory) > maximumMsg
|
||||||
summarize(a.chathistory)
|
summarize(a.chathistory)
|
||||||
else
|
else
|
||||||
d = Dict(:role=> role, :text=> text, :timestamp=> Dates.now())
|
d = Dict(:name=> name, :text=> text, :timestamp=> Dates.now())
|
||||||
push!(a.chathistory, d)
|
push!(a.chathistory, d)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Convert a chat dictionary into LLM model instruct format.
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
name::T
|
||||||
|
message owner name e.f. "system", "user" or "assistant"
|
||||||
|
text::T
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
formattedtext::String
|
||||||
|
text formatted to model format
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia> using Revise
|
||||||
|
julia> d = Dict(:name=> "system",:text=> "You are a helpful, respectful and honest assistant.",)
|
||||||
|
julia> formattedtext = formatLLMtext_llama3instruct(d[:name], d[:text])
|
||||||
|
```
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function formatLLMtext_llama3instruct(name::T, text::T) where {T<:AbstractString}
|
||||||
|
formattedtext =
|
||||||
|
if name == "system"
|
||||||
|
"""<|begin_of_text|>
|
||||||
|
<|start_header_id|>$name<|end_header_id|>
|
||||||
|
$text
|
||||||
|
<|eot_id|>
|
||||||
|
"""
|
||||||
|
else
|
||||||
|
"""
|
||||||
|
<|start_header_id|>$name<|end_header_id|>
|
||||||
|
$text
|
||||||
|
<|eot_id|>
|
||||||
|
"""
|
||||||
|
end
|
||||||
|
|
||||||
|
return formattedtext
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
""" Convert a chat messages in vector of dictionary into LLM model instruct format.
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
messages::Vector{Dict{Symbol, T}}
|
||||||
|
message owner name e.f. "system", "user" or "assistant"
|
||||||
|
formatname::T
|
||||||
|
format name to be used
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
formattedtext::String
|
||||||
|
text formatted to model format
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia> using Revise
|
||||||
|
julia> chatmessage = [
|
||||||
|
Dict(:name=> "system",:text=> "You are a helpful, respectful and honest assistant.",),
|
||||||
|
Dict(:name=> "user",:text=> "list me all planets in our solar system.",),
|
||||||
|
]
|
||||||
|
julia> formattedtext = formatLLMtext(chatmessage, "llama3instruct")
|
||||||
|
```
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function formatLLMtext(messages::Vector{Dict{Symbol, T}},
|
||||||
|
formatname::String="llama3instruct") where {T<:Any}
|
||||||
|
f = if formatname == "llama3instruct"
|
||||||
|
formatLLMtext_llama3instruct
|
||||||
|
elseif formatname == "mistral"
|
||||||
|
# not define yet
|
||||||
|
else
|
||||||
|
error("$formatname template not define yet")
|
||||||
|
end
|
||||||
|
|
||||||
|
str = ""
|
||||||
|
for t in messages
|
||||||
|
str *= f(t[:name], t[:text])
|
||||||
|
end
|
||||||
|
|
||||||
|
return str
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user