This commit is contained in:
narawat lamaiin
2024-04-20 20:25:24 +07:00
parent ff8b20716d
commit b8d036e800

View File

@@ -13,8 +13,26 @@ using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] update docstring
[] implement the function
Signature\n
-----
""" """
struct MCTSNode{T} struct MCTSNode{T}
state::T state::T
@@ -23,9 +41,27 @@ struct MCTSNode{T}
children::Dict{T, MCTSNode} children::Dict{T, MCTSNode}
end end
""" Traversing tree """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] update docstring
[] implement the function
Signature\n
-----
""" """
function select(node::MCTSNode, c::Float64) function select(node::MCTSNode, c::Float64)
max_uct = -Inf max_uct = -Inf
@@ -44,8 +80,26 @@ function select(node::MCTSNode, c::Float64)
end end
""" """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] update docstring
[] implement the function
Signature\n
-----
""" """
function expand(node::MCTSNode, state::T, actions::Vector{T}) function expand(node::MCTSNode, state::T, actions::Vector{T})
for action in actions for action in actions
@@ -57,8 +111,26 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
end end
""" """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] update docstring
[] implement the function
Signature\n
-----
""" """
function simulate(state::T, max_depth::Int) function simulate(state::T, max_depth::Int)
total_reward = 0.0 total_reward = 0.0
@@ -71,8 +143,26 @@ function simulate(state::T, max_depth::Int)
end end
""" """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] update docstring
[] implement the function
Signature\n
-----
""" """
function backpropagate(node::MCTSNode, reward::Float64) function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1 node.visits += 1
@@ -84,29 +174,79 @@ function backpropagate(node::MCTSNode, reward::Float64)
end end
""" """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] implement transition() [] update docstring
[] implement the function
Signature\n
-----
""" """
function transition(state, action) function transition(state, action)
end end
""" """ Check whether a node is a leaf node
TODO\n
[] update docstring
[] implement isLeaf()
"""
function isLeaf(node::MCTSNode)
end Arguments\n
-----
Return\n
-----
a task represent an agent
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[DONE] implement isLeaf()
Signature\n
-----
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
# ------------------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions # # Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------------------ #
""" """
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n TODO\n
[] update docstring -----
[] update docstring
Signature\n
-----
""" """
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64) function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
root = MCTSNode(initial_state, 0, 0.0, Dict()) root = MCTSNode(initial_state, 0, 0.0, Dict())