This commit is contained in:
narawat lamaiin
2024-04-21 16:19:32 +07:00
parent b8d036e800
commit ee1446b1e2
4 changed files with 651 additions and 147 deletions

View File

@@ -3,9 +3,9 @@
and functions for the MCTS algorithm:
"""
module MCTS
module mcts
# export
export runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
@@ -37,7 +37,7 @@ using GeneralUtils
struct MCTSNode{T}
state::T
visits::Int
total_reward::Float64
stateValue::Float64
children::Dict{T, MCTSNode}
end
@@ -45,7 +45,10 @@ end
Arguments\n
-----
node::MCTSNode
mcts node
w::Float64
exploration weight
Return\n
-----
@@ -58,25 +61,70 @@ end
TODO\n
-----
[] update docstring
[] implement the function
[DONE] check childNode.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
function select(node::MCTSNode, c::Float64)
function select(node::MCTSNode, w::Float64)
max_uct = -Inf
selected_node = nothing
selectedNode = nothing
for (child_state, child_node) in node.children
uct_value = child_node.total_reward / child_node.visits +
c * sqrt(log(node.visits) / child_node.visits)
if uct_value > max_uct
max_uct = uct_value
selected_node = child_node
for (childState, childNode) in node.children
uctValue = childNode.stateValue +
w * sqrt(log(node.visits) / childNode.visits)
if uctValue > max_uct
max_uct = uctValue
selectedNode = childNode
end
end
return selected_node
return selectedNode
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
    n::Integer=3) where {T<:Any}
    # Purpose: add one unvisited child to `node.children` for every candidate action,
    # keyed by the state that action leads to.
    #
    # Arguments
    #   node                 node whose `children` dictionary is extended in place
    #   state                only used here to fix the key type `T` of each new child's
    #                        own `children` dictionary
    #   decisionMaker        accepted but not called yet (action sampling is unimplemented)
    #   stateValueEstimator  accepted but not called yet
    #   n                    intended number of actions to sample; currently unused
    #
    # Returns `nothing`.

    # TODO: sample `n` actions from `decisionMaker`. Until that is implemented the list
    # stays empty, so the loop below adds no children.
    actions = []

    for action in actions
        # `transition` is not defined in this function; it must be provided elsewhere.
        newState = transition(node.state, action)
        # Register each resulting state only once so that an already known child keeps
        # whatever statistics it has accumulated.
        if !haskey(node.children, newState)
            node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
        end
    end
    return nothing
end
"""
@@ -101,38 +149,7 @@ end
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, actions::Vector{T})
for action in actions
new_state = transition(node.state, action) # Implement your transition function
if new_state keys(node.children)
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function simulate(state::T, max_depth::Int)
function simulate(state::T, max_depth::Int) where {T<:Any}
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
@@ -224,9 +241,6 @@ end
"""
function isLeaf(node::MCTSNode)::Bool
    # A node counts as a leaf when its `children` dictionary holds no entries.
    return isempty(node.children)
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
"""
Arguments\n
@@ -244,37 +258,128 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
TODO\n
-----
[] update docstring
[] implement the function
[] implement RAG to pull similar experience
Signature\n
-----
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
root = MCTSNode(initial_state, 0, 0.0, Dict())
function decisionMaker()
for _ in 1:max_iterations
node = root
while !isLeaf(node)
node = select(node, w)
end
end
expand(node, node.state, actions)
"""
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, max_depth)
backpropagate(leaf_node, reward)
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
    # Placeholder: no estimation logic exists yet, so the call yields `nothing`.
    return nothing
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
    # Placeholder: no reflection logic exists yet, so the call yields `nothing`.
    return nothing
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
""" Search for best action
Arguments\n
-----
initial state
initial state
decisionMaker::Function
decide what action to take
stateValueEstimator::Function
assess the value of the state
reflector::Function
generate lesson from trajectory and reward
n::Integer
how many times action will be sampled from decisionMaker
w::Float64
exploration weight
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
    reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
    maxIterations::Integer, w::Float64)
    # Purpose: run `maxIterations` rounds of select -> expand -> simulate -> backpropagate
    # from `initialState` and report the state of the best child of the root.
    #
    # Arguments
    #   initialState         state stored in the root node
    #   decisionMaker        forwarded to `expand`
    #   stateValueEstimator  forwarded to `expand`
    #   reflector            accepted but not used by this function yet
    #   totalActionSampled   forwarded to `expand` as its `n` keyword
    #   maxDepth             forwarded to `simulate`
    #   maxIterations        number of search rounds
    #   w                    exploration weight forwarded to `select`
    #
    # Returns the `state` of the root child with the largest `stateValue`, or `nothing`
    # when the root never received any children.

    # The children dictionary must be typed to match `MCTSNode{T}`; an untyped `Dict()`
    # does not match the constructor.
    root = MCTSNode(initialState, 0, 0.0, Dict{typeof(initialState), MCTSNode}())

    for _ in 1:maxIterations
        # Selection: walk down until a node without children is reached.
        node = root
        while !isLeaf(node)
            node = select(node, w)
        end

        expand(node, node.state, decisionMaker, stateValueEstimator; n=totalActionSampled)

        # Keep the original lookup when a child is keyed by the parent's own state;
        # otherwise fall back instead of raising KeyError.
        # NOTE(review): taking the first child is a placeholder policy - confirm intent.
        leafNode = get(node.children, node.state) do
            isLeaf(node) ? node : first(values(node.children))
        end
        reward = simulate(leafNode.state, maxDepth)
        backpropagate(leafNode, reward)
    end

    # Nothing to choose from if expansion never produced a child.
    isLeaf(root) && return nothing

    # Rank by `stateValue` directly, the same quantity `select` uses as its exploitation
    # term. NOTE(review): assumes `stateValue` is already an average - confirm.
    rootChildren = collect(values(root.children))
    bestChild = rootChildren[argmax([child.stateValue for child in rootChildren])]
    return bestChild.state
end
# Define your transition function and action selection function here
# Example usage: search from state 0 using the placeholder callbacks of this module.
initial_state = 0
# Kept for reference; `runMCTS` samples actions through `decisionMaker` instead of
# receiving an explicit action list.
actions = [-1, 0, 1]
# Positional arguments after the callbacks: totalActionSampled, maxDepth, maxIterations, w.
best_action = runMCTS(initial_state, decisionMaker, stateValueEstimator, reflector,
    3, 10, 1000, 1.0)
println("Best action to take: ", best_action)