update
This commit is contained in:
62
src/mcts.jl
62
src/mcts.jl
@@ -42,12 +42,16 @@ julia> state = Dict(
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
|
||||
# Signature
|
||||
"""
|
||||
struct MCTSNode{T<:AbstractDict}
|
||||
statekey::String
|
||||
state::T
|
||||
visits::Integer
|
||||
stateValue::AbstractFloat
|
||||
progressValue::Number
|
||||
children::Dict{String, MCTSNode}
|
||||
end
|
||||
|
||||
@@ -90,12 +94,16 @@ end
|
||||
""" Expand selected node
|
||||
|
||||
# Arguments
|
||||
- `a::T1`
|
||||
One of YiemAgent's agent
|
||||
- `node::MCTSNode`
|
||||
MCTS node
|
||||
- `state::T`
|
||||
- `state::T2`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
- `decisionMaker::Function`
|
||||
|
||||
a function that output Thought and Action
|
||||
- `progressValueEstimator::Function`
|
||||
a function that output trajectory progress score
|
||||
|
||||
# Return
|
||||
|
||||
@@ -104,14 +112,10 @@ end
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [WORKING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
@@ -120,15 +124,12 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||
@show thoughtDict
|
||||
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
|
||||
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
|
||||
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
|
||||
# add progressValueEstimator
|
||||
_, progressValue = progressValueEstimator(a, newstate)
|
||||
|
||||
if newStatekey ∉ keys(node.children)
|
||||
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
|
||||
end
|
||||
|
||||
# add stateValueEstimator
|
||||
|
||||
|
||||
|
||||
|
||||
end
|
||||
end
|
||||
|
||||
@@ -145,23 +146,24 @@ julia>
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [] implement the function
|
||||
- [WORKING] implement the function
|
||||
- [] reward only comes at terminal state
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
||||
total_reward = 0.0
|
||||
for _ in 1:max_depth
|
||||
#[] Implement your action selection function based on highest stateValue
|
||||
action = select_action(state) # current state
|
||||
state, reward = transition(state, action) # Implement transition function to a new state
|
||||
error("--> simulate")
|
||||
total_reward = 0.0
|
||||
for _ in 1:max_depth
|
||||
#[] Implement your action selection function based on highest stateValue
|
||||
action = select_action(state) # current state
|
||||
state, reward = transition(state, action) # Implement transition function to a new state
|
||||
|
||||
#[] check for the terminal state
|
||||
#[] check for the terminal state
|
||||
|
||||
total_reward += reward
|
||||
end
|
||||
return total_reward
|
||||
total_reward += reward
|
||||
end
|
||||
return total_reward
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -332,7 +334,7 @@ end
|
||||
initial state
|
||||
- `decisionMaker::Function`
|
||||
decide what action to take
|
||||
- `stateValueEstimator::Function`
|
||||
- `progressValueEstimator::Function`
|
||||
assess the value of the state
|
||||
- `reflector::Function`
|
||||
generate lesson from trajectory and reward
|
||||
@@ -361,7 +363,7 @@ function runMCTS(
|
||||
a::T1,
|
||||
initialState,
|
||||
decisionMaker::Function,
|
||||
stateValueEstimator::Function,
|
||||
progressValueEstimator::Function,
|
||||
reflector::Function,
|
||||
isterminal::Function,
|
||||
n::Integer,
|
||||
@@ -369,7 +371,7 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
|
||||
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
|
||||
|
||||
for _ in 1:maxIterations
|
||||
node = root
|
||||
@@ -377,7 +379,7 @@ function runMCTS(
|
||||
node = select(node, w)
|
||||
end
|
||||
|
||||
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
|
||||
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
|
||||
|
||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||
leaf_node = node
|
||||
|
||||
Reference in New Issue
Block a user