diff --git a/src/mcts.jl b/src/mcts.jl index 28810e4..9548e71 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -124,28 +124,35 @@ julia> ``` # TODO - - [] update docstring + [] update docstring + [] try loop should limit to 3 times. if not succeed, skip # Signature """ function expand(a::T1, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent} - # sampling action from decisionMaker - for sample in 1:n - thoughtDict = decisionMaker(a, node.state) + nthSample = 0 + while nthSample < n + try + thoughtDict = decisionMaker(a, node.state) - newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict, - isterminal) - - # add progressValueEstimator - stateevaluation, statevalue = progressValueEstimator(a, newstate) + newNodeKey, newstate, isterminalstate, reward = + MCTStransition(a, node.state, thoughtDict, isterminal) + + # add progressValueEstimator + stateevaluation, statevalue = progressValueEstimator(a, newstate) - if newNodeKey ∉ keys(node.children) - node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue, - reward, isterminalstate, node, Dict{String, MCTSNode}()) + if newNodeKey ∉ keys(node.children) + node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue, + reward, isterminalstate, node, Dict{String, MCTSNode}()) + end + nthSample += 1 + catch + # skip this child node if error occurs + println("retry node expand") end - end + end end """ @@ -177,14 +184,9 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim if node.isterminal break else - try - simTrajectoryReward += node.reward - expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) - node = selectChildNode(node) - catch - # if error occurs, break and try again later - break - end + simTrajectoryReward += node.reward + expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) + node = selectChildNode(node) end end @@ -354,7 +356,7 @@ function selectChildNode(node::MCTSNode)::MCTSNode # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children thisNodeProgressValue = childNode.statevalue + childNode.reward - if childNode.statevalue > highestProgressValue + if thisNodeProgressValue > highestProgressValue highestProgressValue = thisNodeProgressValue nodekey = childNode.nodekey end