update
This commit is contained in:
@@ -12,6 +12,10 @@ using GeneralUtils
|
|||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
struct MCTSNode{T}
|
struct MCTSNode{T}
|
||||||
state::T
|
state::T
|
||||||
visits::Int
|
visits::Int
|
||||||
@@ -19,6 +23,10 @@ struct MCTSNode{T}
|
|||||||
children::Dict{T, MCTSNode}
|
children::Dict{T, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function select(node::MCTSNode, c::Float64)
|
function select(node::MCTSNode, c::Float64)
|
||||||
max_uct = -Inf
|
max_uct = -Inf
|
||||||
selected_node = nothing
|
selected_node = nothing
|
||||||
@@ -35,6 +43,10 @@ function select(node::MCTSNode, c::Float64)
|
|||||||
return selected_node
|
return selected_node
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||||
for action in actions
|
for action in actions
|
||||||
new_state = transition(node.state, action) # Implement your transition function
|
new_state = transition(node.state, action) # Implement your transition function
|
||||||
@@ -44,6 +56,10 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function simulate(state::T, max_depth::Int)
|
function simulate(state::T, max_depth::Int)
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
@@ -54,6 +70,10 @@ function simulate(state::T, max_depth::Int)
|
|||||||
return total_reward
|
return total_reward
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
node.total_reward += reward
|
node.total_reward += reward
|
||||||
@@ -63,16 +83,29 @@ function backpropagate(node::MCTSNode, reward::Float64)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
[] implement transition()
|
||||||
|
"""
|
||||||
|
function transition(state, action)
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
# Create a complete example using the defined MCTS functions #
|
# Create a complete example using the defined MCTS functions #
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
|
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
||||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||||
|
|
||||||
for _ in 1:max_iterations
|
for _ in 1:max_iterations
|
||||||
node = root
|
node = root
|
||||||
while !is_leaf(node)
|
while !is_leaf(node)
|
||||||
node = select(node, c)
|
node = select(node, w)
|
||||||
end
|
end
|
||||||
|
|
||||||
expand(node, node.state, actions)
|
expand(node, node.state, actions)
|
||||||
@@ -94,12 +127,6 @@ actions = [-1, 0, 1]
|
|||||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
||||||
println("Best action to take: ", best_action)
|
println("Best action to take: ", best_action)
|
||||||
|
|
||||||
In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
54
src/mcts.jl
54
src/mcts.jl
@@ -12,6 +12,10 @@ using GeneralUtils
|
|||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
struct MCTSNode{T}
|
struct MCTSNode{T}
|
||||||
state::T
|
state::T
|
||||||
visits::Int
|
visits::Int
|
||||||
@@ -19,6 +23,10 @@ struct MCTSNode{T}
|
|||||||
children::Dict{T, MCTSNode}
|
children::Dict{T, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
""" Traversing tree
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function select(node::MCTSNode, c::Float64)
|
function select(node::MCTSNode, c::Float64)
|
||||||
max_uct = -Inf
|
max_uct = -Inf
|
||||||
selected_node = nothing
|
selected_node = nothing
|
||||||
@@ -35,6 +43,10 @@ function select(node::MCTSNode, c::Float64)
|
|||||||
return selected_node
|
return selected_node
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||||
for action in actions
|
for action in actions
|
||||||
new_state = transition(node.state, action) # Implement your transition function
|
new_state = transition(node.state, action) # Implement your transition function
|
||||||
@@ -44,6 +56,10 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function simulate(state::T, max_depth::Int)
|
function simulate(state::T, max_depth::Int)
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
@@ -54,6 +70,10 @@ function simulate(state::T, max_depth::Int)
|
|||||||
return total_reward
|
return total_reward
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
node.total_reward += reward
|
node.total_reward += reward
|
||||||
@@ -63,16 +83,38 @@ function backpropagate(node::MCTSNode, reward::Float64)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
[] implement transition()
|
||||||
|
"""
|
||||||
|
function transition(state, action)
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
[] implement isLeaf()
|
||||||
|
"""
|
||||||
|
function isLeaf(node::MCTSNode)
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
# Create a complete example using the defined MCTS functions #
|
# Create a complete example using the defined MCTS functions #
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
|
"""
|
||||||
|
TODO\n
|
||||||
|
[] update docstring
|
||||||
|
"""
|
||||||
|
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
||||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||||
|
|
||||||
for _ in 1:max_iterations
|
for _ in 1:max_iterations
|
||||||
node = root
|
node = root
|
||||||
while !is_leaf(node)
|
while !isLeaf(node)
|
||||||
node = select(node, c)
|
node = select(node, w)
|
||||||
end
|
end
|
||||||
|
|
||||||
expand(node, node.state, actions)
|
expand(node, node.state, actions)
|
||||||
@@ -94,12 +136,6 @@ actions = [-1, 0, 1]
|
|||||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
||||||
println("Best action to take: ", best_action)
|
println("Best action to take: ", best_action)
|
||||||
|
|
||||||
In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user