update
This commit is contained in:
81
src/mcts.jl
81
src/mcts.jl
@@ -64,22 +64,21 @@ end
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
mcts node
|
||||
- `w::Float64`
|
||||
exploration weight
|
||||
- `w::T`
|
||||
exploration weight. Value is usually between 1 to 2.
|
||||
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
|
||||
Value 2.0 makes MCTS aggressively search the tree.
|
||||
# Return
|
||||
- `selectedNode::MCTSNode`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
[x] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function UCTselect(node::MCTSNode, w::Float64)
|
||||
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||
max_uct = -Inf
|
||||
selectedNode = nothing
|
||||
|
||||
@@ -130,7 +129,7 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
|
||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
|
||||
|
||||
nthSample = 0
|
||||
while nthSample < n
|
||||
@@ -138,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, reward, isterminalstate =
|
||||
MCTStransition(a, node.state, thoughtDict, isterminal)
|
||||
MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
@@ -148,69 +147,78 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
nthSample += 1
|
||||
catch
|
||||
# skip this child node if error occurs
|
||||
println("retry node expand")
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
errorMsg = String(take!(io))
|
||||
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
|
||||
println("")
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
|
||||
""" Simulate interactions between agent and environment
|
||||
|
||||
# Arguments
|
||||
- `a::T`
|
||||
one of YiemAgent's agent
|
||||
- `node::MCTSNode`
|
||||
node that will be a simulation starting point.
|
||||
- `decisionMaker::Function`
|
||||
function that receive state return Thought and Action
|
||||
|
||||
# Return
|
||||
- `simTrajectoryReward::Number`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
- [] check for the terminal state (node.reward != 0), break if it is terminal state
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
isterminal::Function, maxDepth::Int; n=3)::Number
|
||||
function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
maxDepth::Int; n=3)::Number where {T<:agent}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
|
||||
for depth in 1:maxDepth
|
||||
simTrajectoryReward += node.reward
|
||||
if node.isterminalrd
|
||||
if node.isterminal
|
||||
break
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
#BUG new expanded state has reward but it is not included because it is over maxdept by 1 state
|
||||
|
||||
return simTrajectoryReward
|
||||
end
|
||||
|
||||
"""
|
||||
""" Backpropagate reward along the simulation chain
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
- `simTrajectoryReward::T`
|
||||
total reward from all node in simulation trajectory
|
||||
|
||||
# Return
|
||||
- `No return`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [WORKING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
|
||||
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||
while !isroot(node)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
@@ -260,8 +268,8 @@ julia> thoughtDict = Dict(
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T3
|
||||
)::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
|
||||
actionname = thoughtDict[:action][:name]
|
||||
actioninput = thoughtDict[:action][:input]
|
||||
@@ -383,10 +391,6 @@ end
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docs
|
||||
[TESTING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
|
||||
@@ -437,7 +441,6 @@ function runMCTS(
|
||||
decisionMaker::Function,
|
||||
progressValueEstimator::Function,
|
||||
reflector::Function,
|
||||
isterminal::Function,
|
||||
n::Integer,
|
||||
maxDepth::Integer,
|
||||
maxIterations::Integer,
|
||||
@@ -455,10 +458,10 @@ function runMCTS(
|
||||
# do nothing then go directly to backpropagation
|
||||
backpropagate(leafNode, node.reward)
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||
leafNode = UCTselect(node, w)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
maxDepth, n=n)
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user