update
This commit is contained in:
24
src/mcts.jl
24
src/mcts.jl
@@ -124,19 +124,21 @@ julia>
|
|||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
[] update docstring
|
||||||
|
[] try loop should limit to 3 times. if not succeed, skip
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||||
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
|
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
|
||||||
|
|
||||||
# sampling action from decisionMaker
|
nthSample = 0
|
||||||
for sample in 1:n
|
while nthSample < n
|
||||||
|
try
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
|
|
||||||
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
|
newNodeKey, newstate, isterminalstate, reward =
|
||||||
isterminal)
|
MCTStransition(a, node.state, thoughtDict, isterminal)
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||||
@@ -145,6 +147,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
|
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
|
||||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
|
nthSample += 1
|
||||||
|
catch
|
||||||
|
# skip this child node if error occurs
|
||||||
|
println("retry node expand")
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -177,14 +184,9 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
|
|||||||
if node.isterminal
|
if node.isterminal
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
try
|
|
||||||
simTrajectoryReward += node.reward
|
simTrajectoryReward += node.reward
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
catch
|
|
||||||
# if error occurs, break and try again later
|
|
||||||
break
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -354,7 +356,7 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
|||||||
# loop thought node children dictionary to find the highest progress value
|
# loop thought node children dictionary to find the highest progress value
|
||||||
for (k, childNode) in node.children
|
for (k, childNode) in node.children
|
||||||
thisNodeProgressValue = childNode.statevalue + childNode.reward
|
thisNodeProgressValue = childNode.statevalue + childNode.reward
|
||||||
if childNode.statevalue > highestProgressValue
|
if thisNodeProgressValue > highestProgressValue
|
||||||
highestProgressValue = thisNodeProgressValue
|
highestProgressValue = thisNodeProgressValue
|
||||||
nodekey = childNode.nodekey
|
nodekey = childNode.nodekey
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user