update
This commit is contained in:
@@ -108,9 +108,12 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
|
|
||||||
|
|
||||||
else #[WORKING] new thinking
|
else #[WORKING] new thinking
|
||||||
|
initialState = Dict(
|
||||||
|
:info=> Dict(), # keyword info
|
||||||
initialState = 0
|
:thought=> nothing,
|
||||||
|
:action=> nothing,
|
||||||
|
:observation=> nothing,
|
||||||
|
)
|
||||||
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
||||||
3, 10, 1000, 1.0)
|
3, 10, 1000, 1.0)
|
||||||
error("---> bestplan")
|
error("---> bestplan")
|
||||||
|
|||||||
125
src/mcts.jl
125
src/mcts.jl
@@ -5,17 +5,22 @@
|
|||||||
|
|
||||||
module mcts
|
module mcts
|
||||||
|
|
||||||
export runMCTS
|
export MCTSNode, runMCTS, decisionMaker, stateValueEstimator, reflector
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3, Random
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
"""
|
""" a node for MCTS search tree
|
||||||
|
|
||||||
Arguments\n
|
Arguments\n
|
||||||
-----
|
-----
|
||||||
|
state::T
|
||||||
|
Represents a state of the game. Can be a Dict or something else.
|
||||||
|
visits::Integer
|
||||||
|
number of times the game visits this state
|
||||||
|
stateValue::Float64
|
||||||
|
|
||||||
Return\n
|
Return\n
|
||||||
-----
|
-----
|
||||||
@@ -29,15 +34,15 @@ using GeneralUtils
|
|||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[] implement the function
|
[DONE] implement the function
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
struct MCTSNode{T}
|
struct MCTSNode{T}
|
||||||
state::T
|
state::T
|
||||||
visits::Int
|
visits::Integer
|
||||||
stateValue::Float64
|
stateValue::AbstractFloat
|
||||||
children::Dict{T, MCTSNode}
|
children::Dict{T, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -107,23 +112,12 @@ end
|
|||||||
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
|
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
|
||||||
n::Integer=3) where {T<:Any}
|
n::Integer=3) where {T<:Any}
|
||||||
|
|
||||||
actions = []
|
|
||||||
|
|
||||||
# sampling action from decisionMaker
|
# sampling action from decisionMaker
|
||||||
# for nth in 1:n
|
for sample in 1:n
|
||||||
|
newState = transition(node.state, action) #[] Implement your transition function
|
||||||
|
if newState ∉ keys(node.children)
|
||||||
# end
|
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for action in actions
|
|
||||||
newState = transition(node.state, action) # Implement your transition function
|
|
||||||
if newState ∉ keys(node.children)
|
|
||||||
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -145,6 +139,7 @@ end
|
|||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[] implement the function
|
[] implement the function
|
||||||
|
[] reward only comes at terminal state
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
@@ -152,9 +147,13 @@ end
|
|||||||
function simulate(state::T, max_depth::Int) where {T<:Any}
|
function simulate(state::T, max_depth::Int) where {T<:Any}
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
action = select_action(state) # Implement your action selection function
|
#[] Implement your action selection function based on highest stateValue
|
||||||
state, reward = transition(state, action) # Implement your transition function
|
action = select_action(state) # current state
|
||||||
total_reward += reward
|
state, reward = transition(state, action) # Implement transition function to a new state
|
||||||
|
|
||||||
|
#[] check for the terminal state
|
||||||
|
|
||||||
|
total_reward += reward
|
||||||
end
|
end
|
||||||
return total_reward
|
return total_reward
|
||||||
end
|
end
|
||||||
@@ -183,7 +182,9 @@ end
|
|||||||
"""
|
"""
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
node.total_reward += reward
|
|
||||||
|
# [] there is no total_reward in the paper, but they use stateValue
|
||||||
|
node.total_reward += reward
|
||||||
if !isempty(node.children)
|
if !isempty(node.children)
|
||||||
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||||
backpropagate(node.children[best_child], -reward)
|
backpropagate(node.children[best_child], -reward)
|
||||||
@@ -216,25 +217,27 @@ function transition(state, action)
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Check whether a node is a leaf node
|
""" Check whether a node is a leaf node of a tree
|
||||||
|
|
||||||
Arguments\n
|
Arguments\n
|
||||||
-----
|
-----
|
||||||
|
node::MCTSNode
|
||||||
|
node of a tree
|
||||||
|
|
||||||
Return\n
|
Return\n
|
||||||
-----
|
-----
|
||||||
a task represent an agent
|
result::Bool
|
||||||
|
true if the node is a leaf node of a tree, otherwise false
|
||||||
|
|
||||||
Example\n
|
Example\n
|
||||||
-----
|
-----
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia>
|
julia> using
|
||||||
```
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
-----
|
-----
|
||||||
[] update docstring
|
[] update docstring
|
||||||
[DONE] implement isLeaf()
|
|
||||||
|
|
||||||
Signature\n
|
Signature\n
|
||||||
-----
|
-----
|
||||||
@@ -320,6 +323,34 @@ function reflector()
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO\n
|
||||||
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
[] implement RAG to pull similar experience
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
function isTerminal()
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
# Create a complete example using the defined MCTS functions #
|
# Create a complete example using the defined MCTS functions #
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
@@ -342,6 +373,8 @@ end
|
|||||||
|
|
||||||
Return\n
|
Return\n
|
||||||
-----
|
-----
|
||||||
|
plan::Vector{Dict}
|
||||||
|
best plan
|
||||||
|
|
||||||
Example\n
|
Example\n
|
||||||
-----
|
-----
|
||||||
@@ -357,26 +390,26 @@ end
|
|||||||
-----
|
-----
|
||||||
"""
|
"""
|
||||||
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
|
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
|
||||||
reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
|
reflector::Function, n::Integer, maxDepth::Integer,
|
||||||
maxIterations::Integer, w::Float64)
|
maxIterations::Integer, w::Float64)
|
||||||
root = MCTSNode(initialState, 0, 0.0, Dict())
|
statetype = typeof(initialState)
|
||||||
|
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||||
for _ in 1:maxIterations
|
error("---> runMCTS")
|
||||||
node = root
|
for _ in 1:maxIterations
|
||||||
while !isLeaf(node)
|
node = root
|
||||||
node = select(node, w)
|
while !isLeaf(node)
|
||||||
end
|
node = select(node, w)
|
||||||
|
|
||||||
expand(node, node.state, decisionMaker, stateValueEstimator,
|
|
||||||
n=n)
|
|
||||||
|
|
||||||
leaf_node = node.children[node.state]
|
|
||||||
reward = simulate(leaf_node.state, maxDepth)
|
|
||||||
backpropagate(leaf_node, reward)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
expand(node, node.state, decisionMaker, stateValueEstimator, n=n)
|
||||||
return best_child_state
|
|
||||||
|
leaf_node = node.children[node.state] # mark leaf node
|
||||||
|
reward = simulate(leaf_node.state, maxDepth)
|
||||||
|
backpropagate(leaf_node, reward)
|
||||||
|
end
|
||||||
|
|
||||||
|
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||||
|
return best_child_state
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ tools=Dict( # update input format
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = YiemAgent.conversation(a, Dict(:text=> "newtopic", ) )
|
response = YiemAgent.conversation(a, Dict(:text=> "hello", ) )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user