update
This commit is contained in:
@@ -12,13 +12,21 @@ using GeneralUtils
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
struct MCTSNode{T}
|
||||
state::T
|
||||
visits::Int
|
||||
total_reward::Float64
|
||||
children::Dict{T, MCTSNode}
|
||||
state::T
|
||||
visits::Int
|
||||
total_reward::Float64
|
||||
children::Dict{T, MCTSNode}
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function select(node::MCTSNode, c::Float64)
|
||||
max_uct = -Inf
|
||||
selected_node = nothing
|
||||
@@ -35,6 +43,10 @@ function select(node::MCTSNode, c::Float64)
|
||||
return selected_node
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||
for action in actions
|
||||
new_state = transition(node.state, action) # Implement your transition function
|
||||
@@ -44,6 +56,10 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function simulate(state::T, max_depth::Int)
|
||||
total_reward = 0.0
|
||||
for _ in 1:max_depth
|
||||
@@ -54,6 +70,10 @@ function simulate(state::T, max_depth::Int)
|
||||
return total_reward
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, reward::Float64)
|
||||
node.visits += 1
|
||||
node.total_reward += reward
|
||||
@@ -63,16 +83,29 @@ function backpropagate(node::MCTSNode, reward::Float64)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
[] implement transition()
|
||||
"""
|
||||
function transition(state, action)
|
||||
|
||||
end
|
||||
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
# Create a complete example using the defined MCTS functions #
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||
|
||||
for _ in 1:max_iterations
|
||||
node = root
|
||||
while !is_leaf(node)
|
||||
node = select(node, c)
|
||||
node = select(node, w)
|
||||
end
|
||||
|
||||
expand(node, node.state, actions)
|
||||
@@ -94,12 +127,6 @@ actions = [-1, 0, 1]
|
||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
||||
println("Best action to take: ", best_action)
|
||||
|
||||
In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
62
src/mcts.jl
62
src/mcts.jl
@@ -12,13 +12,21 @@ using GeneralUtils
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
struct MCTSNode{T}
|
||||
state::T
|
||||
visits::Int
|
||||
total_reward::Float64
|
||||
children::Dict{T, MCTSNode}
|
||||
state::T
|
||||
visits::Int
|
||||
total_reward::Float64
|
||||
children::Dict{T, MCTSNode}
|
||||
end
|
||||
|
||||
""" Traversing tree
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function select(node::MCTSNode, c::Float64)
|
||||
max_uct = -Inf
|
||||
selected_node = nothing
|
||||
@@ -35,6 +43,10 @@ function select(node::MCTSNode, c::Float64)
|
||||
return selected_node
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||
for action in actions
|
||||
new_state = transition(node.state, action) # Implement your transition function
|
||||
@@ -44,6 +56,10 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function simulate(state::T, max_depth::Int)
|
||||
total_reward = 0.0
|
||||
for _ in 1:max_depth
|
||||
@@ -54,6 +70,10 @@ function simulate(state::T, max_depth::Int)
|
||||
return total_reward
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, reward::Float64)
|
||||
node.visits += 1
|
||||
node.total_reward += reward
|
||||
@@ -63,16 +83,38 @@ function backpropagate(node::MCTSNode, reward::Float64)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
[] implement transition()
|
||||
"""
|
||||
function transition(state, action)
|
||||
|
||||
end
|
||||
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
[] implement isLeaf()
|
||||
"""
|
||||
function isLeaf(node::MCTSNode)
|
||||
|
||||
end
|
||||
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
# Create a complete example using the defined MCTS functions #
|
||||
# ------------------------------------------------------------------------------------------------ #
|
||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
|
||||
"""
|
||||
TODO\n
|
||||
[] update docstring
|
||||
"""
|
||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||
|
||||
for _ in 1:max_iterations
|
||||
node = root
|
||||
while !is_leaf(node)
|
||||
node = select(node, c)
|
||||
while !isLeaf(node)
|
||||
node = select(node, w)
|
||||
end
|
||||
|
||||
expand(node, node.state, actions)
|
||||
@@ -94,12 +136,6 @@ actions = [-1, 0, 1]
|
||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
||||
println("Best action to take: ", best_action)
|
||||
|
||||
In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user