This commit is contained in:
narawat lamaiin
2024-04-22 17:41:52 +07:00
parent ee1446b1e2
commit 1962035990
3 changed files with 86 additions and 50 deletions

View File

@@ -5,17 +5,22 @@
module mcts
export runMCTS
export MCTSNode, runMCTS, decisionMaker, stateValueEstimator, reflector
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
""" a node for MCTS search tree
Arguments\n
-----
state::T
Represents a state of the game. Can be a `Dict` or any other type.
visits::Integer
number of times the game has visited this state
stateValue::Float64
Return\n
-----
@@ -29,15 +34,15 @@ using GeneralUtils
TODO\n
-----
[] update docstring
[] implement the function
[DONE] implement the function
Signature\n
-----
"""
struct MCTSNode{T}
state::T
visits::Int
stateValue::Float64
visits::Integer
stateValue::AbstractFloat
children::Dict{T, MCTSNode}
end
@@ -107,23 +112,12 @@ end
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T<:Any}
actions = []
# sampling action from decisionMaker
# for nth in 1:n
# end
for action in actions
newState = transition(node.state, action) # Implement your transition function
if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
for sample in 1:n
newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
@@ -145,6 +139,7 @@ end
-----
[] update docstring
[] implement the function
[] reward only comes at terminal state
Signature\n
-----
@@ -152,9 +147,13 @@ end
function simulate(state::T, max_depth::Int) where {T<:Any}
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
state, reward = transition(state, action) # Implement your transition function
total_reward += reward
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
#[] check for the terminal state
total_reward += reward
end
return total_reward
end
@@ -183,7 +182,9 @@ end
"""
function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1
node.total_reward += reward
# [] there is no total_reward in the paper, but they use stateValue
node.total_reward += reward
if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward)
@@ -216,25 +217,27 @@ function transition(state, action)
end
""" Check whether a node is a leaf node
""" Check whether a node is a leaf node of a tree
Arguments\n
-----
node::MCTSNode
node of a tree
Return\n
-----
a task representing an agent
result::Bool
true if the node is a leaf node of the tree, otherwise false
Example\n
-----
```jldoctest
julia>
julia> using
```
TODO\n
-----
[] update docstring
[DONE] implement isLeaf()
Signature\n
-----
@@ -320,6 +323,34 @@ function reflector()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
[] implement RAG to pull similar experience
Signature\n
-----
"""
# Stub: per the TODO list in the preceding docstring, this is intended to decide
# whether a game state is terminal; currently it takes no arguments and returns
# `nothing` unconditionally. NOTE(review): signature (what state it inspects) is
# not yet defined — confirm intended interface before implementing.
function isTerminal()
end
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
@@ -342,6 +373,8 @@ end
Return\n
-----
plan::Vector{Dict}
best plan
Example\n
-----
@@ -357,26 +390,26 @@ end
-----
"""
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, totalActionSampled::Integer, maxDepth::Integer,
reflector::Function, n::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64)
root = MCTSNode(initialState, 0, 0.0, Dict())
for _ in 1:maxIterations
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(node, node.state, decisionMaker, stateValueEstimator,
n=n)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
error("---> runMCTS")
for _ in 1:maxIterations
node = root
while !isLeaf(node)
node = select(node, w)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
expand(node, node.state, decisionMaker, stateValueEstimator, n=n)
leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end