diff --git a/src/mcts copy.jl b/src/mcts copy.jl
index a2fb267..df65a5c 100644
--- a/src/mcts copy.jl
+++ b/src/mcts copy.jl
@@ -12,13 +12,21 @@ using GeneralUtils
 
 # ---------------------------------------------- 100 --------------------------------------------- #
 
+"""
+    TODO\n
+    [] update docstring
+"""
 struct MCTSNode{T}
-	state::T
-	visits::Int
-	total_reward::Float64
-	children::Dict{T, MCTSNode}
+    state::T
+    visits::Int
+    total_reward::Float64
+    children::Dict{T, MCTSNode}
 end
 
+"""
+    TODO\n
+    [] update docstring
+"""
 function select(node::MCTSNode, c::Float64)
     max_uct = -Inf
     selected_node = nothing
@@ -35,6 +43,10 @@ function select(node::MCTSNode, c::Float64)
     return selected_node
 end
 
+"""
+    TODO\n
+    [] update docstring
+"""
 function expand(node::MCTSNode, state::T, actions::Vector{T})
     for action in actions
         new_state = transition(node.state, action) # Implement your transition function
@@ -44,6 +56,10 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
     end
 end
 
+"""
+    TODO\n
+    [] update docstring
+"""
 function simulate(state::T, max_depth::Int)
     total_reward = 0.0
     for _ in 1:max_depth
@@ -54,6 +70,10 @@ function simulate(state::T, max_depth::Int)
     end
     return total_reward
 end
+"""
+    TODO\n
+    [] update docstring
+"""
 function backpropagate(node::MCTSNode, reward::Float64)
     node.visits += 1
     node.total_reward += reward
@@ -63,16 +83,29 @@ function backpropagate(node::MCTSNode, reward::Float64)
     end
 end
 
+"""
+    TODO\n
+    [] update docstring
+    [] implement transition()
+"""
+function transition(state, action)
+
+end
+
 # ------------------------------------------------------------------------------------------------ #
 # Create a complete example using the defined MCTS functions                                        #
 # ------------------------------------------------------------------------------------------------ #
 
-function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
+"""
+    TODO\n
+    [] update docstring
+"""
+function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
     root = MCTSNode(initial_state, 0, 0.0, Dict())
     for _ in 1:max_iterations
         node = root
         while !is_leaf(node)
-            node = select(node, c)
+            node = select(node, w)
         end
 
         expand(node, node.state, actions)
@@ -94,12 +127,6 @@ actions = [-1, 0, 1]
 
 best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
 println("Best action to take: ", best_action)
 
 
 
-In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
-
-
-
-
-
diff --git a/src/mcts.jl b/src/mcts.jl
index a2fb267..54491d7 100644
--- a/src/mcts.jl
+++ b/src/mcts.jl
@@ -12,13 +12,21 @@ using GeneralUtils
 
 # ---------------------------------------------- 100 --------------------------------------------- #
 
+"""
+    TODO\n
+    [] update docstring
+"""
 struct MCTSNode{T}
-	state::T
-	visits::Int
-	total_reward::Float64
-	children::Dict{T, MCTSNode}
+    state::T
+    visits::Int
+    total_reward::Float64
+    children::Dict{T, MCTSNode}
 end
 
+""" Traversing tree
+    TODO\n
+    [] update docstring
+"""
 function select(node::MCTSNode, c::Float64)
     max_uct = -Inf
     selected_node = nothing
@@ -35,6 +43,10 @@ function select(node::MCTSNode, c::Float64)
     return selected_node
 end
 
+"""
+    TODO\n
+    [] update docstring
+"""
 function expand(node::MCTSNode, state::T, actions::Vector{T})
     for action in actions
         new_state = transition(node.state, action) # Implement your transition function
@@ -44,6 +56,10 @@ function expand(node::MCTSNode, state::T, actions::Vector{T})
     end
 end
 
+"""
+    TODO\n
+    [] update docstring
+"""
 function simulate(state::T, max_depth::Int)
     total_reward = 0.0
     for _ in 1:max_depth
@@ -54,6 +70,10 @@ function simulate(state::T, max_depth::Int)
     end
     return total_reward
 end
+"""
+    TODO\n
+    [] update docstring
+"""
 function backpropagate(node::MCTSNode, reward::Float64)
     node.visits += 1
     node.total_reward += reward
@@ -63,16 +83,38 @@ function backpropagate(node::MCTSNode, reward::Float64)
     end
 end
 
+"""
+    TODO\n
+    [] update docstring
+    [] implement transition()
+"""
+function transition(state, action)
+
+end
+
+"""
+    TODO\n
+    [] update docstring
+    [] implement isLeaf()
+"""
+function isLeaf(node::MCTSNode)
+
+end
+
 # ------------------------------------------------------------------------------------------------ #
 # Create a complete example using the defined MCTS functions                                        #
 # ------------------------------------------------------------------------------------------------ #
 
-function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
+"""
+    TODO\n
+    [] update docstring
+"""
+function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
     root = MCTSNode(initial_state, 0, 0.0, Dict())
     for _ in 1:max_iterations
         node = root
-        while !is_leaf(node)
-            node = select(node, c)
+        while !isLeaf(node)
+            node = select(node, w)
         end
 
         expand(node, node.state, actions)
@@ -94,12 +136,6 @@ actions = [-1, 0, 1]
 
 best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
 println("Best action to take: ", best_action)
 
 
 
-In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
-
-
-
-
-