update
This commit is contained in:
172
src/mcts.jl
172
src/mcts.jl
@@ -13,8 +13,26 @@ using GeneralUtils
|
|||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
struct MCTSNode{T}
|
struct MCTSNode{T}
|
||||||
state::T
|
state::T
|
||||||
@@ -23,9 +41,27 @@ struct MCTSNode{T}
|
|||||||
children::Dict{T, MCTSNode}
|
children::Dict{T, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Traversing tree
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
function select(node::MCTSNode, c::Float64)
|
function select(node::MCTSNode, c::Float64)
|
||||||
max_uct = -Inf
|
max_uct = -Inf
|
||||||
@@ -44,8 +80,26 @@ function select(node::MCTSNode, c::Float64)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||||
for action in actions
|
for action in actions
|
||||||
@@ -57,8 +111,26 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
function simulate(state::T, max_depth::Int)
|
function simulate(state::T, max_depth::Int)
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
@@ -71,8 +143,26 @@ function simulate(state::T, max_depth::Int)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
@@ -84,29 +174,79 @@ function backpropagate(node::MCTSNode, reward::Float64)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
[] implement transition()
|
[] update docstring
|
||||||
|
[] implement the function
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
function transition(state, action)
|
function transition(state, action)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
""" Check whether a node is a leaf node
|
||||||
TODO\n
|
|
||||||
[] update docstring
|
|
||||||
[] implement isLeaf()
|
|
||||||
"""
|
|
||||||
function isLeaf(node::MCTSNode)
|
|
||||||
|
|
||||||
end
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
a task represent an agent
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO\n
|
||||||
|
-----
|
||||||
|
[] update docstring
|
||||||
|
[DONE] implement isLeaf()
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
|
"""
|
||||||
|
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
# Create a complete example using the defined MCTS functions #
|
# Create a complete example using the defined MCTS functions #
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Arguments\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Return\n
|
||||||
|
-----
|
||||||
|
|
||||||
|
Example\n
|
||||||
|
-----
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
TODO\n
|
TODO\n
|
||||||
[] update docstring
|
-----
|
||||||
|
[] update docstring
|
||||||
|
|
||||||
|
Signature\n
|
||||||
|
-----
|
||||||
"""
|
"""
|
||||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
||||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||||
|
|||||||
Reference in New Issue
Block a user