This commit is contained in:
2025-03-11 00:13:25 +07:00
parent b1d655acff
commit 097484675c
2 changed files with 62 additions and 54 deletions

View File

@@ -85,10 +85,11 @@ function runMCTS(
# do nothing then go directly to backpropagation. It means the end of this iteration # do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
else else
println(111)
_ = expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
horizontalSample=horizontalSampleExpansionPhase) horizontalSample=horizontalSampleExpansionPhase)
println(555) println(666)
@sync for (leafNodeKey, leafNode) in node.children @sync for (leafNodeKey, leafNode) in node.children
@spawn simulateThenBackpropagate(leafNode, transition, transitionargs; @spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
@@ -96,6 +97,14 @@ function runMCTS(
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode) saveSimulatedNode=saveSimulatedNode)
end end
#CHANGE for testing
# for (leafNodeKey, leafNode) in node.children
# simulateThenBackpropagate(leafNode, transition, transitionargs;
# maxSimulationDepth=maxSimulationDepth,
# horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
# saveSimulatedNode=saveSimulatedNode)
# end
end end
# stop if the early stop condition is met # stop if the early stop condition is met

View File

@@ -198,68 +198,67 @@ end
# Signature # Signature
""" """
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3) horizontalSample::Integer=3)
@sync for i in 1:horizontalSample
@spawn _expand(node, transition, transitionargs)
end
# # not use Any[] because I want to preserve result order #CHANGE for testing
# results = Vector{Any}(undef, totalsample) # for i in 1:horizontalSample
# _expand(node, transition, transitionargs)
# end
end
# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
# horizontalSample::Integer=3)
# @sync for i in 1:totalsample # nthSample = 0
# @spawn begin # listOfNewNodeId = []
# result = transition(deepcopy(node.state), deepcopy(transitionargs)) # while true
# results[i] = result # nthSample += 1
# if nthSample <= horizontalSample
# result = transition(node.state, transitionargs)
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# push!(listOfNewNodeId, newNodeKey)
# newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
# node.children[newNodeKey] = newNode
# end # end
# end # else
# return listOfNewNodeId
# for result in results
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}())
# end # end
# end # end
# end # end
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
horizontalSample::Integer=3)
nthSample = 0 function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
listOfNewNodeId = [] result = transition(node.state, transitionargs)
while true newNodeKey::AbstractString = result[:newNodeKey]
nthSample += 1 newstate::AbstractDict = result[:newstate]
if nthSample <= horizontalSample progressvalue::Integer = result[:progressvalue]
result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate]
progressvalue::Integer = result[:progressvalue]
""" """
[] newNodeKey ∉ keys(node.children). [] newNodeKey ∉ keys(node.children).
New state may have semantic vector close enought to New state may have semantic vector close enought to
one of existing child state. Which can be assume that they are the same state one of existing child state. Which can be assume that they are the same state
semantically-wise i.e. De javu. This could be used to recall lessons for this semantically-wise i.e. De javu. This could be used to recall lessons for this
similar situation to improve decisionMaker and evaluator. similar situation to improve decisionMaker and evaluator.
""" """
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
push!(listOfNewNodeId, newNodeKey) newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) node.children[newNodeKey] = newNode
node.children[newNodeKey] = newNode
end
else
return listOfNewNodeId
end end
end
end end