update
This commit is contained in:
@@ -51,7 +51,7 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[x] implement the function
|
[PENDING] implement the function
|
||||||
[] implement RAG to pull similar experience
|
[] implement RAG to pull similar experience
|
||||||
[] use iterative prompting to ensure JSON format
|
[] use iterative prompting to ensure JSON format
|
||||||
|
|
||||||
@@ -143,8 +143,8 @@ function decisionMaker(a::T1, state::T2) where {T1<:agent, T2<:AbstractDict}
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
thought = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
|
thought = result[:response][:text]
|
||||||
return thought
|
return thought
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -288,7 +288,7 @@ end
|
|||||||
function conversation(a::T, userinput::Dict) where {T<:agent}
|
function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||||
"""
|
"""
|
||||||
[] update document
|
[] update document
|
||||||
[x] MCTS() for planning
|
[PENDING] MCTS() for planning
|
||||||
"""
|
"""
|
||||||
# "newtopic" command to delete chat history
|
# "newtopic" command to delete chat history
|
||||||
if userinput[:text] == "newtopic"
|
if userinput[:text] == "newtopic"
|
||||||
|
|||||||
332
src/mcts.jl
332
src/mcts.jl
@@ -15,43 +15,38 @@ using ..type
|
|||||||
|
|
||||||
""" a node for MCTS search tree
|
""" a node for MCTS search tree
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
- `state::T`
|
||||||
state::T
|
a state of a game. Can be a Dict or something else.
|
||||||
a state of a game. Can be a Dict or something else.
|
For example:
|
||||||
For example:
|
state = Dict(
|
||||||
state = Dict(
|
:info=> Dict(), # keyword info
|
||||||
:info=> Dict(), # keyword info
|
:thoughtHistory=> Dict(
|
||||||
:thoughtHistory=> Dict(
|
:question=> _,
|
||||||
:question=> _,
|
:thought_1=> _,
|
||||||
:thought_1=> _,
|
:action_1=> _,
|
||||||
:action_1=> _,
|
:observation_1=> _,
|
||||||
:observation_1=> _,
|
:thought_2=> _,
|
||||||
:thought_2=> _,
|
...
|
||||||
...
|
)
|
||||||
)
|
)
|
||||||
)
|
- `visits::Integer `
|
||||||
visits::Integer
|
number of time the game visits this state
|
||||||
number of time the game visits this state
|
- `stateValue::Float64`
|
||||||
stateValue::Float64
|
state value
|
||||||
state value
|
|
||||||
|
|
||||||
Return\n
|
# Return
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
[] update docstring
|
||||||
[] update docstring
|
[x] implement the function
|
||||||
[DONE] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
struct MCTSNode{T<:AbstractDict}
|
struct MCTSNode{T<:AbstractDict}
|
||||||
state::T
|
state::T
|
||||||
@@ -62,28 +57,23 @@ end
|
|||||||
|
|
||||||
""" Select a node based on UCT score
|
""" Select a node based on UCT score
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
- `node::MCTSNode`
|
||||||
node::MCTSNode
|
mcts node
|
||||||
mcts node
|
- `w::Float64`
|
||||||
w::Float64
|
exploration weight
|
||||||
exploration weight
|
# Return
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
[] update docstring
|
||||||
[] update docstring
|
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
||||||
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function select(node::MCTSNode, w::Float64)
|
function select(node::MCTSNode, w::Float64)
|
||||||
max_uct = -Inf
|
max_uct = -Inf
|
||||||
@@ -103,31 +93,26 @@ end
|
|||||||
|
|
||||||
""" Expand selected node
|
""" Expand selected node
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
- `node::MCTSNode`
|
||||||
node::MCTSNode
|
MCTS node
|
||||||
MCTS node
|
- `state::T`
|
||||||
state::T
|
a state of a game. Can be a Dict or something else.
|
||||||
a state of a game. Can be a Dict or something else.
|
- `decisionMaker::Function`
|
||||||
decisionMaker::Function
|
|
||||||
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
# Return
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
- [] update docstring
|
||||||
[] update docstring
|
- [WORKING] implement the function
|
||||||
[x] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
|
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
|
||||||
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||||
@@ -149,26 +134,21 @@ end
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
|
||||||
|
# Return
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
- [] update docstring
|
||||||
[] update docstring
|
- [] implement the function
|
||||||
[] implement the function
|
- [] reward only comes at terminal state
|
||||||
[] reward only comes at terminal state
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
@@ -186,25 +166,20 @@ end
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
|
||||||
|
# Return
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
- [] update docstring
|
||||||
[] update docstring
|
- [] implement the function
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
@@ -219,79 +194,61 @@ end
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
|
||||||
|
# Return
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
- [] update docstring
|
||||||
[] update docstring
|
- [] implement the function
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function transition(state, action)
|
function transition(state, action)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Check whether a node is a leaf node of a tree
|
"""
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
|
||||||
node::MCTSNode
|
# Return
|
||||||
node of a tree
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
result::Bool
|
|
||||||
true if the node is a leaf node of a tree otherwise false
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia> using
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
- [] update docstring
|
||||||
[] update docstring
|
- [x] implement the function
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
|
||||||
|
# Return
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
- [] update docstring
|
||||||
[] update docstring
|
- [] implement the function
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function executeLLMFunction()
|
function executeLLMFunction()
|
||||||
|
|
||||||
@@ -303,42 +260,37 @@ end
|
|||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
""" Search the best action to take for a given state and task
|
""" Search the best action to take for a given state and task
|
||||||
|
|
||||||
Arguments\n
|
# Arguments
|
||||||
-----
|
- `a::agent`
|
||||||
a::agent
|
one of Yiem's agents
|
||||||
one of Yiem's agents
|
- `initial state`
|
||||||
initial state
|
initial state
|
||||||
initial state
|
- `decisionMaker::Function`
|
||||||
decisionMaker::Function
|
decide what action to take
|
||||||
decide what action to take
|
- `stateValueEstimator::Function`
|
||||||
stateValueEstimator::Function
|
assess the value of the state
|
||||||
assess the value of the state
|
- `reflector::Function`
|
||||||
reflector::Function
|
generate lesson from trajectory and reward
|
||||||
generate lesson from trajectory and reward
|
- `isterminal::Function`
|
||||||
isterminal::Function
|
determine whether a given state is a terminal state
|
||||||
determine whether a given state is a terminal state
|
- `n::Integer`
|
||||||
n::Integer
|
how many times action will be sampled from decisionMaker
|
||||||
how many times action will be sampled from decisionMaker
|
- `w::Float64`
|
||||||
w::Float64
|
exploration weight
|
||||||
exploration weight
|
|
||||||
|
# Return
|
||||||
Return\n
|
- `plan::Vector{Dict}`
|
||||||
-----
|
best plan
|
||||||
plan::Vector{Dict}
|
|
||||||
best plan
|
|
||||||
|
|
||||||
Example\n
|
# Example
|
||||||
-----
|
```jldoctest
|
||||||
```jldoctest
|
julia>
|
||||||
julia>
|
```
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
# TODO
|
||||||
-----
|
[] update docstring
|
||||||
[] update docstring
|
|
||||||
|
|
||||||
Signature\n
|
# Signature
|
||||||
-----
|
|
||||||
"""
|
"""
|
||||||
function runMCTS(
|
function runMCTS(
|
||||||
a::T1,
|
a::T1,
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ end
|
|||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[WORKING] implement the function
|
[PENDING] implement the function
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
@@ -332,7 +332,7 @@ end
|
|||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[WORKING] implement the function
|
[TESTING] implement the function
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
|
|||||||
@@ -100,10 +100,6 @@ response = YiemAgent.conversation(a, Dict(:text=> "Hello, I would like a get a b
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user