update
This commit is contained in:
@@ -16,6 +16,9 @@ module YiemAgent
|
|||||||
include("llmfunction.jl")
|
include("llmfunction.jl")
|
||||||
using .llmfunction
|
using .llmfunction
|
||||||
|
|
||||||
|
include("mcts.jl")
|
||||||
|
using .mcts
|
||||||
|
|
||||||
include("interface.jl")
|
include("interface.jl")
|
||||||
using .interface
|
using .interface
|
||||||
|
|
||||||
|
|||||||
111
src/mcts copy.jl
Normal file
111
src/mcts copy.jl
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
|
||||||
|
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
|
||||||
|
and functions for the MCTS algorithm:
|
||||||
|
"""
|
||||||
|
|
||||||
|
module MCTS
|
||||||
|
|
||||||
|
# export
|
||||||
|
|
||||||
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||||
|
using GeneralUtils
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
struct MCTSNode{T}
|
||||||
|
state::T
|
||||||
|
visits::Int
|
||||||
|
total_reward::Float64
|
||||||
|
children::Dict{T, MCTSNode}
|
||||||
|
end
|
||||||
|
|
||||||
|
function select(node::MCTSNode, c::Float64)
|
||||||
|
max_uct = -Inf
|
||||||
|
selected_node = nothing
|
||||||
|
|
||||||
|
for (child_state, child_node) in node.children
|
||||||
|
uct_value = child_node.total_reward / child_node.visits +
|
||||||
|
c * sqrt(log(node.visits) / child_node.visits)
|
||||||
|
if uct_value > max_uct
|
||||||
|
max_uct = uct_value
|
||||||
|
selected_node = child_node
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return selected_node
|
||||||
|
end
|
||||||
|
|
||||||
|
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||||
|
for action in actions
|
||||||
|
new_state = transition(node.state, action) # Implement your transition function
|
||||||
|
if new_state ∉ keys(node.children)
|
||||||
|
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function simulate(state::T, max_depth::Int)
|
||||||
|
total_reward = 0.0
|
||||||
|
for _ in 1:max_depth
|
||||||
|
action = select_action(state) # Implement your action selection function
|
||||||
|
state, reward = transition(state, action) # Implement your transition function
|
||||||
|
total_reward += reward
|
||||||
|
end
|
||||||
|
return total_reward
|
||||||
|
end
|
||||||
|
|
||||||
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
|
node.visits += 1
|
||||||
|
node.total_reward += reward
|
||||||
|
if !isempty(node.children)
|
||||||
|
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||||
|
backpropagate(node.children[best_child], -reward)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
# Create a complete example using the defined MCTS functions #
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
|
||||||
|
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||||
|
|
||||||
|
for _ in 1:max_iterations
|
||||||
|
node = root
|
||||||
|
while !is_leaf(node)
|
||||||
|
node = select(node, c)
|
||||||
|
end
|
||||||
|
|
||||||
|
expand(node, node.state, actions)
|
||||||
|
|
||||||
|
leaf_node = node.children[node.state]
|
||||||
|
reward = simulate(leaf_node.state, max_depth)
|
||||||
|
backpropagate(leaf_node, reward)
|
||||||
|
end
|
||||||
|
|
||||||
|
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||||
|
return best_child_state
|
||||||
|
end
|
||||||
|
|
||||||
|
# Define your transition function and action selection function here
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
initial_state = 0
|
||||||
|
actions = [-1, 0, 1]
|
||||||
|
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
||||||
|
println("Best action to take: ", best_action)
|
||||||
|
|
||||||
|
In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end
|
||||||
111
src/mcts.jl
Normal file
111
src/mcts.jl
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
|
||||||
|
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
|
||||||
|
and functions for the MCTS algorithm:
|
||||||
|
"""
|
||||||
|
|
||||||
|
module MCTS
|
||||||
|
|
||||||
|
# export
|
||||||
|
|
||||||
|
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||||
|
using GeneralUtils
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
struct MCTSNode{T}
|
||||||
|
state::T
|
||||||
|
visits::Int
|
||||||
|
total_reward::Float64
|
||||||
|
children::Dict{T, MCTSNode}
|
||||||
|
end
|
||||||
|
|
||||||
|
function select(node::MCTSNode, c::Float64)
|
||||||
|
max_uct = -Inf
|
||||||
|
selected_node = nothing
|
||||||
|
|
||||||
|
for (child_state, child_node) in node.children
|
||||||
|
uct_value = child_node.total_reward / child_node.visits +
|
||||||
|
c * sqrt(log(node.visits) / child_node.visits)
|
||||||
|
if uct_value > max_uct
|
||||||
|
max_uct = uct_value
|
||||||
|
selected_node = child_node
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return selected_node
|
||||||
|
end
|
||||||
|
|
||||||
|
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
||||||
|
for action in actions
|
||||||
|
new_state = transition(node.state, action) # Implement your transition function
|
||||||
|
if new_state ∉ keys(node.children)
|
||||||
|
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function simulate(state::T, max_depth::Int)
|
||||||
|
total_reward = 0.0
|
||||||
|
for _ in 1:max_depth
|
||||||
|
action = select_action(state) # Implement your action selection function
|
||||||
|
state, reward = transition(state, action) # Implement your transition function
|
||||||
|
total_reward += reward
|
||||||
|
end
|
||||||
|
return total_reward
|
||||||
|
end
|
||||||
|
|
||||||
|
function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
|
node.visits += 1
|
||||||
|
node.total_reward += reward
|
||||||
|
if !isempty(node.children)
|
||||||
|
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||||
|
backpropagate(node.children[best_child], -reward)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
# Create a complete example using the defined MCTS functions #
|
||||||
|
# ------------------------------------------------------------------------------------------------ #
|
||||||
|
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, c::Float64)
|
||||||
|
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
||||||
|
|
||||||
|
for _ in 1:max_iterations
|
||||||
|
node = root
|
||||||
|
while !is_leaf(node)
|
||||||
|
node = select(node, c)
|
||||||
|
end
|
||||||
|
|
||||||
|
expand(node, node.state, actions)
|
||||||
|
|
||||||
|
leaf_node = node.children[node.state]
|
||||||
|
reward = simulate(leaf_node.state, max_depth)
|
||||||
|
backpropagate(leaf_node, reward)
|
||||||
|
end
|
||||||
|
|
||||||
|
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||||
|
return best_child_state
|
||||||
|
end
|
||||||
|
|
||||||
|
# Define your transition function and action selection function here
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
initial_state = 0
|
||||||
|
actions = [-1, 0, 1]
|
||||||
|
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
||||||
|
println("Best action to take: ", best_action)
|
||||||
|
|
||||||
|
In this example, you define the MCTS algorithm with the UCT selection function and then create a complete example of using the MCTS algorithm to find the best action to take in a given state space with a set of actions. You can customize the transition function, action selection function, and parameters to suit your specific problem domain.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end
|
||||||
Reference in New Issue
Block a user