This commit is contained in:
narawat lamaiin
2024-05-03 10:32:41 +07:00
parent 8262423317
commit ef940b6ada
3 changed files with 111 additions and 50 deletions

View File

@@ -42,12 +42,16 @@ julia> state = Dict(
)
```
# TODO
- [ ] update docstring
# Signature
"""
# A node of a Monte Carlo Tree Search tree.
# NOTE(review): this span comes from a rendered diff hunk with no +/- markers;
# `stateValue` appears to be the pre-change field and `progressValue` its
# replacement (later hunks construct nodes with only a progress value) —
# confirm against the full file before relying on both fields existing.
struct MCTSNode{T<:AbstractDict}
statekey::String # key identifying this node among its parent's `children`
state::T # game state; assumed Dict-like (T<:AbstractDict) — TODO confirm schema
visits::Integer # visit counter; NOTE(review): abstract field type — prefer `Int` for performance
stateValue::AbstractFloat # accumulated/estimated value; NOTE(review): abstract field type
progressValue::Number # trajectory progress score from `progressValueEstimator`; abstract field type
children::Dict{String, MCTSNode} # child nodes keyed by their `statekey`
end
@@ -90,12 +94,16 @@ end
""" Expand selected node
# Arguments
- `a::T1`
One of YiemAgent's agents
- `node::MCTSNode`
MCTS node
- `state::T`
- `state::T2`
a state of a game. Can be a Dict or something else.
- `decisionMaker::Function`
a function that outputs a Thought and an Action
- `progressValueEstimator::Function`
a function that outputs a trajectory progress score
# Return
@@ -104,14 +112,10 @@ end
julia>
```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker
for sample in 1:n
@@ -120,15 +124,12 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
@show thoughtDict
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
if newStatekey keys(node.children)# BUG should be "key of the newstate" here not newstate itself
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
# add progressValueEstimator
_, progressValue = progressValueEstimator(a, newstate)
if newStatekey keys(node.children)
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
end
# add stateValueEstimator
end
end
@@ -145,23 +146,24 @@ julia>
# TODO
- [] update docstring
- [] implement the function
- [WORKING] implement the function
- [] reward only comes at terminal state
# Signature
"""
# Roll out a simulated playout from `state` for up to `max_depth` steps,
# summing rewards returned by `transition`.
# NOTE(review): this span interleaves the pre- and post-change bodies from a
# diff rendering with no +/- markers — the first loop (before `error`) looks
# like the removed version and everything after `error(...)` like the added
# version. Confirm against the actual file before treating this as runnable.
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
total_reward = 0.0
for _ in 1:max_depth
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
# NOTE(review): unconditional abort — the new body below is a stub guarded by this error
error("--> simulate")
total_reward = 0.0
for _ in 1:max_depth
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
#[] check for the terminal state
#[] check for the terminal state
total_reward += reward
end
return total_reward
total_reward += reward
end
return total_reward
end
"""
@@ -332,7 +334,7 @@ end
initial state
- `decisionMaker::Function`
decide what action to take
- `stateValueEstimator::Function`
- `progressValueEstimator::Function`
assess the value of the state
- `reflector::Function`
generate lesson from trajectory and reward
@@ -361,7 +363,7 @@ function runMCTS(
a::T1,
initialState,
decisionMaker::Function,
stateValueEstimator::Function,
progressValueEstimator::Function,
reflector::Function,
isterminal::Function,
n::Integer,
@@ -369,7 +371,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root
@@ -377,7 +379,7 @@ function runMCTS(
node = select(node, w)
end
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
# from paper, just start simulation at this node. Not the node that newly expanded
leaf_node = node