This commit is contained in:
narawat lamaiin
2025-03-02 17:10:01 +07:00
parent 84d73e742c
commit 6920be2334
3 changed files with 17 additions and 7 deletions

View File

@@ -234,6 +234,7 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3)
nthSample = 0
listOfNewNodeId = []
while true
nthSample += 1
if nthSample <= totalsample
@@ -250,12 +251,13 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
similar situation to improve decisionMaker and evaluator.
"""
if newNodeKey keys(node.children)
push!(listOfNewNodeId, newNodeKey)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}())
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
end
else
break
return listOfNewNodeId
end
end
end
@@ -286,6 +288,7 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
simTrajectoryReward = 0.0
terminalstate = nothing
# listOfSimulatedNodeId = []
for depth in 1:maxdepth
simTrajectoryReward += node.reward
@@ -293,10 +296,10 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
terminalstate = node.state
break
else
expand(node, transition, transitionargs;
totalsample=totalsample)
_ = expand(node, transition, transitionargs;
totalsample=totalsample)
node = selectChildNode(node)
end
end
end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)