This commit is contained in:
2025-03-11 00:13:25 +07:00
parent b1d655acff
commit 097484675c
2 changed files with 62 additions and 54 deletions

View File

@@ -85,10 +85,11 @@ function runMCTS(
# do nothing then go directly to backpropagation. It means the end of this iteration # do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
else else
println(111)
_ = expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
horizontalSample=horizontalSampleExpansionPhase) horizontalSample=horizontalSampleExpansionPhase)
println(555) println(666)
@sync for (leafNodeKey, leafNode) in node.children @sync for (leafNodeKey, leafNode) in node.children
@spawn simulateThenBackpropagate(leafNode, transition, transitionargs; @spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
@@ -96,6 +97,14 @@ function runMCTS(
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode) saveSimulatedNode=saveSimulatedNode)
end end
#CHANGE for testing
# for (leafNodeKey, leafNode) in node.children
# simulateThenBackpropagate(leafNode, transition, transitionargs;
# maxSimulationDepth=maxSimulationDepth,
# horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
# saveSimulatedNode=saveSimulatedNode)
# end
end end
# stop if the early stop condition is met # stop if the early stop condition is met

View File

@@ -198,20 +198,26 @@ end
# Signature # Signature
""" """
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
horizontalSample::Integer=3)
@sync for i in 1:horizontalSample
@spawn _expand(node, transition, transitionargs)
end
#CHANGE for testing
# for i in 1:horizontalSample
# _expand(node, transition, transitionargs)
# end
end
# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; # function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3) # horizontalSample::Integer=3)
# # not use Any[] because I want to preserve result order # nthSample = 0
# results = Vector{Any}(undef, totalsample) # listOfNewNodeId = []
# while true
# @sync for i in 1:totalsample # nthSample += 1
# @spawn begin # if nthSample <= horizontalSample
# result = transition(deepcopy(node.state), deepcopy(transitionargs)) # result = transition(node.state, transitionargs)
# results[i] = result
# end
# end
# for result in results
# newNodeKey::AbstractString = result[:newNodeKey] # newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate] # newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue] # progressvalue::Integer = result[:progressvalue]
@@ -224,20 +230,18 @@ end
# similar situation to improve decisionMaker and evaluator. # similar situation to improve decisionMaker and evaluator.
# """ # """
# if newNodeKey ∉ keys(node.children) # if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] = # push!(listOfNewNodeId, newNodeKey)
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], # newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}()) # newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
# node.children[newNodeKey] = newNode
# end
# else
# return listOfNewNodeId
# end # end
# end # end
# end # end
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
horizontalSample::Integer=3)
nthSample = 0 function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
listOfNewNodeId = []
while true
nthSample += 1
if nthSample <= horizontalSample
result = transition(node.state, transitionargs) result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate] newstate::AbstractDict = result[:newstate]
@@ -251,15 +255,10 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
similar situation to improve decisionMaker and evaluator. similar situation to improve decisionMaker and evaluator.
""" """
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
push!(listOfNewNodeId, newNodeKey)
newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
node.children[newNodeKey] = newNode node.children[newNodeKey] = newNode
end end
else
return listOfNewNodeId
end
end
end end