This commit is contained in:
narawat lamaiin
2024-05-06 20:07:41 +07:00
parent 1fae63126f
commit 4608835c95

View File

@@ -124,19 +124,21 @@ julia>
``` ```
# TODO # TODO
- [] update docstring [] update docstring
[] try loop should limit to 3 times. if not succeed, skip
# Signature # Signature
""" """
function expand(a::T1, node::MCTSNode, decisionMaker::Function, function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent} progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
# sampling action from decisionMaker nthSample = 0
for sample in 1:n while nthSample < n
try
thoughtDict = decisionMaker(a, node.state) thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict, newNodeKey, newstate, isterminalstate, reward =
isterminal) MCTStransition(a, node.state, thoughtDict, isterminal)
# add progressValueEstimator # add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate) stateevaluation, statevalue = progressValueEstimator(a, newstate)
@@ -145,6 +147,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue, node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}()) reward, isterminalstate, node, Dict{String, MCTSNode}())
end end
nthSample += 1
catch
# skip this child node if error occurs
println("retry node expand")
end
end end
end end
@@ -177,14 +184,9 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
if node.isterminal if node.isterminal
break break
else else
try
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
node = selectChildNode(node) node = selectChildNode(node)
catch
# if error occurs, break and try again later
break
end
end end
end end
@@ -354,7 +356,7 @@ function selectChildNode(node::MCTSNode)::MCTSNode
# loop thought node children dictionary to find the highest progress value # loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children for (k, childNode) in node.children
thisNodeProgressValue = childNode.statevalue + childNode.reward thisNodeProgressValue = childNode.statevalue + childNode.reward
if childNode.statevalue > highestProgressValue if thisNodeProgressValue > highestProgressValue
highestProgressValue = thisNodeProgressValue highestProgressValue = thisNodeProgressValue
nodekey = childNode.nodekey nodekey = childNode.nodekey
end end