This commit is contained in:
2025-03-14 12:31:41 +07:00
parent 097484675c
commit 7e160f2031
2 changed files with 36 additions and 113 deletions

View File

@@ -66,7 +66,8 @@ function runMCTS(
maxiterations::Integer=10, maxiterations::Integer=10,
explorationweight::Number=1.0, explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing, earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false) where {T<:Any} saveSimulatedNode::Bool=false,
multithread=false) where {T<:Any}
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} # )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
@@ -85,26 +86,26 @@ function runMCTS(
# do nothing then go directly to backpropagation. It means the end of this iteration # do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
else else
println(111)
_ = expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
horizontalSample=horizontalSampleExpansionPhase) horizontalSample=horizontalSampleExpansionPhase,
multithread=multithread)
println(666) if multithread
@sync for (leafNodeKey, leafNode) in node.children
@sync for (leafNodeKey, leafNode) in node.children @spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
@spawn simulateThenBackpropagate(leafNode, transition, transitionargs; maxSimulationDepth=maxSimulationDepth,
maxSimulationDepth=maxSimulationDepth, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, saveSimulatedNode=saveSimulatedNode,
saveSimulatedNode=saveSimulatedNode) multithread=multithread)
end
else
for (leafNodeKey, leafNode) in node.children
simulateThenBackpropagate(leafNode, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode,
multithread=multithread)
end
end end
#CHANGE for testing
# for (leafNodeKey, leafNode) in node.children
# simulateThenBackpropagate(leafNode, transition, transitionargs;
# maxSimulationDepth=maxSimulationDepth,
# horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
# saveSimulatedNode=saveSimulatedNode)
# end
end end
# stop if the early stop condition is met # stop if the early stop condition is met
@@ -123,10 +124,12 @@ end
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false) saveSimulatedNode::Bool=false,
multithread=false)
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs; simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase) horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread)
backpropagate(node, simTrajectoryReward) backpropagate(node, simTrajectoryReward)
# check if the user wants to keep the simulated node # check if the user wants to keep the simulated node
@@ -137,58 +140,6 @@ end
# function runMCTS(
# initialstate::T,
# transition::Function,
# transitionargs::NamedTuple,
# ;
# totalsample::Integer=3,
# maxdepth::Integer=3,
# maxiterations::Integer=10,
# explorationweight::Number=1.0,
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
# for nth in 1:maxiterations
# node = root
# node.visits += 1
# while !isleaf(node)
# node = UCTselect(node, explorationweight)
# end
# if node.isterminal
# # MCTS arrive at the leaf node that is also a terminal state,
# # do nothing then go directly to backpropagation. It means the end of this iteration
# backpropagate(leafNode, node.reward)
# else
# expand(node, transition, transitionargs;
# totalsample=totalsample)
# leafNode = selectChildNode(node)
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
# maxdepth=maxdepth, totalsample=totalsample)
# # if terminalstate !== nothing #XXX not sure why I need this
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# # end
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# # open("trajectory.json", "w") do io
# # JSON3.pretty(io, terminalstate)
# # end
# backpropagate(leafNode, simTrajectoryReward)
# end
# end
# bestNextState = selectBestNextNode(root)
# besttrajectory = selectBestTrajectoryNode(root)
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
# end

View File

@@ -199,47 +199,18 @@ end
# Signature # Signature
""" """
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
horizontalSample::Integer=3) horizontalSample::Integer=3, multithread=false)
@sync for i in 1:horizontalSample if multithread
@spawn _expand(node, transition, transitionargs) @sync for i in 1:horizontalSample
@spawn _expand(node, transition, transitionargs)
end
else
for i in 1:horizontalSample
_expand(node, transition, transitionargs)
end
end end
#CHANGE for testing
# for i in 1:horizontalSample
# _expand(node, transition, transitionargs)
# end
end end
# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
# horizontalSample::Integer=3)
# nthSample = 0
# listOfNewNodeId = []
# while true
# nthSample += 1
# if nthSample <= horizontalSample
# result = transition(node.state, transitionargs)
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# push!(listOfNewNodeId, newNodeKey)
# newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
# node.children[newNodeKey] = newNode
# end
# else
# return listOfNewNodeId
# end
# end
# end
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
result = transition(node.state, transitionargs) result = transition(node.state, transitionargs)
@@ -282,7 +253,7 @@ end
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSample::Integer=3) maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false)
# )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}} # )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
@@ -295,7 +266,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
break break
else else
_ = expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
horizontalSample=horizontalSample) horizontalSample=horizontalSample,
multithread=multithread)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end