update
This commit is contained in:
46
src/mcts.jl
46
src/mcts.jl
@@ -124,28 +124,35 @@ julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
[] update docstring
|
||||
[] limit the try loop to 3 attempts; if it still does not succeed, skip
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
nthSample = 0
|
||||
while nthSample < n
|
||||
try
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
|
||||
isterminal)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
newNodeKey, newstate, isterminalstate, reward =
|
||||
MCTStransition(a, node.state, thoughtDict, isterminal)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
nthSample += 1
|
||||
catch
|
||||
# if an error occurs, this sample is not counted, so the loop retries it
|
||||
println("retry node expand")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -177,14 +184,9 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
|
||||
if node.isterminal
|
||||
break
|
||||
else
|
||||
try
|
||||
simTrajectoryReward += node.reward
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
node = selectChildNode(node)
|
||||
catch
|
||||
# if error occurs, break and try again later
|
||||
break
|
||||
end
|
||||
simTrajectoryReward += node.reward
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -354,7 +356,7 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
# loop through the node's children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
thisNodeProgressValue = childNode.statevalue + childNode.reward
|
||||
if childNode.statevalue > highestProgressValue
|
||||
if thisNodeProgressValue > highestProgressValue
|
||||
highestProgressValue = thisNodeProgressValue
|
||||
nodekey = childNode.nodekey
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user