update
This commit is contained in:
158
src/mcts.jl
158
src/mcts.jl
@@ -13,8 +13,26 @@ using GeneralUtils
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
struct MCTSNode{T}
|
||||
state::T
|
||||
@@ -23,9 +41,27 @@ struct MCTSNode{T}
|
||||
children::Dict{T, MCTSNode}
|
||||
end
|
||||
|
||||
""" Traversing tree
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function select(node::MCTSNode, c::Float64)
|
||||
max_uct = -Inf
|
||||
@@ -44,8 +80,26 @@ function select(node::MCTSNode, c::Float64)
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||
for action in actions
|
||||
@@ -57,8 +111,26 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function simulate(state::T, max_depth::Int)
|
||||
total_reward = 0.0
|
||||
@@ -71,8 +143,26 @@ function simulate(state::T, max_depth::Int)
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, reward::Float64)
|
||||
node.visits += 1
|
||||
@@ -84,29 +174,79 @@ function backpropagate(node::MCTSNode, reward::Float64)
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement transition()
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function transition(state, action)
|
||||
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
[] implement isLeaf()
|
||||
"""
|
||||
function isLeaf(node::MCTSNode)
|
||||
""" Check whether a node is a leaf node
|
||||
|
||||
end
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
a task represent an agent
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[DONE] implement isLeaf()
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
# Create a complete example using the defined MCTS functions #
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||
|
||||
Reference in New Issue
Block a user