This commit is contained in:
narawat lamaiin
2024-05-06 20:07:41 +07:00
parent 1fae63126f
commit 4608835c95

View File

@@ -124,19 +124,21 @@ julia>
```
# TODO
- [] update docstring
[] update docstring
[] the try loop should be limited to 3 attempts; if it still does not succeed, skip
# Signature
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
# sampling action from decisionMaker
for sample in 1:n
nthSample = 0
while nthSample < n
try
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
isterminal)
newNodeKey, newstate, isterminalstate, reward =
MCTStransition(a, node.state, thoughtDict, isterminal)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
@@ -145,6 +147,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
nthSample += 1
catch
# an error occurred: nthSample is not advanced, so this sample is retried
println("retry node expand")
end
end
end
@@ -177,14 +184,9 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
if node.isterminal
break
else
try
simTrajectoryReward += node.reward
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
node = selectChildNode(node)
catch
# if error occurs, break and try again later
break
end
end
end
@@ -354,7 +356,7 @@ function selectChildNode(node::MCTSNode)::MCTSNode
# loop through the node's children dictionary to find the highest progress value
for (k, childNode) in node.children
thisNodeProgressValue = childNode.statevalue + childNode.reward
if childNode.statevalue > highestProgressValue
if thisNodeProgressValue > highestProgressValue
highestProgressValue = thisNodeProgressValue
nodekey = childNode.nodekey
end