From 7e160f20313612e372c52b144c127e96d12af345 Mon Sep 17 00:00:00 2001 From: tonaerospace Date: Fri, 14 Mar 2025 12:31:41 +0700 Subject: [PATCH] update --- src/interface.jl | 97 ++++++++++++------------------------------------ src/mcts.jl | 52 ++++++-------------------- 2 files changed, 36 insertions(+), 113 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index c04e59c..d5974b7 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -66,7 +66,8 @@ function runMCTS( maxiterations::Integer=10, explorationweight::Number=1.0, earlystop::Union{Function,Nothing}=nothing, - saveSimulatedNode::Bool=false) where {T<:Any} + saveSimulatedNode::Bool=false, + multithread=false) where {T<:Any} # )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), @@ -85,26 +86,26 @@ function runMCTS( # do nothing then go directly to backpropagation. It means the end of this iteration backpropagate(node, node.reward) else - println(111) _ = expand(node, transition, transitionargs; - horizontalSample=horizontalSampleExpansionPhase) - - println(666) - - @sync for (leafNodeKey, leafNode) in node.children - @spawn simulateThenBackpropagate(leafNode, transition, transitionargs; - maxSimulationDepth=maxSimulationDepth, - horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, - saveSimulatedNode=saveSimulatedNode) + horizontalSample=horizontalSampleExpansionPhase, + multithread=multithread) + if multithread + @sync for (leafNodeKey, leafNode) in node.children + @spawn simulateThenBackpropagate(leafNode, transition, transitionargs; + maxSimulationDepth=maxSimulationDepth, + horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, + saveSimulatedNode=saveSimulatedNode, + multithread=multithread) + end + else + for (leafNodeKey, leafNode) in node.children + simulateThenBackpropagate(leafNode, transition, transitionargs; + maxSimulationDepth=maxSimulationDepth, + horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, + saveSimulatedNode=saveSimulatedNode, + multithread=multithread) + end end - - #CHANGE for testing - # for (leafNodeKey, leafNode) in node.children - # simulateThenBackpropagate(leafNode, transition, transitionargs; - # maxSimulationDepth=maxSimulationDepth, - # horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, - # saveSimulatedNode=saveSimulatedNode) - # end end # stop if the early stop condition is met @@ -123,10 +124,12 @@ end function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, - saveSimulatedNode::Bool=false) + saveSimulatedNode::Bool=false, + multithread=false) simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs; maxSimulationDepth=maxSimulationDepth, - horizontalSample=horizontalSampleSimulationPhase) + horizontalSample=horizontalSampleSimulationPhase, + multithread=multithread) backpropagate(node, simTrajectoryReward) # check if the user wants to keep the simulated node @@ -137,58 +140,6 @@ end -# function runMCTS( -# initialstate::T, -# transition::Function, -# transitionargs::NamedTuple, -# ; -# totalsample::Integer=3, -# maxdepth::Integer=3, -# maxiterations::Integer=10, -# explorationweight::Number=1.0, -# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} - -# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) - -# for nth in 1:maxiterations -# node = root -# node.visits += 1 - -# while !isleaf(node) -# node = UCTselect(node, explorationweight) -# end -# if node.isterminal -# # MCTS arrive at the leaf node that is also a terminal state, -# # do nothing then go directly to backpropagation. It means the end of this iteration -# backpropagate(leafNode, node.reward) -# else -# expand(node, transition, transitionargs; -# totalsample=totalsample) -# leafNode = selectChildNode(node) -# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; -# maxdepth=maxdepth, totalsample=totalsample) -# # if terminalstate !== nothing #XXX not sure why I need this -# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward -# # end - -# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation -# # open("trajectory.json", "w") do io -# # JSON3.pretty(io, terminalstate) -# # end - -# backpropagate(leafNode, simTrajectoryReward) -# end -# end - -# bestNextState = selectBestNextNode(root) -# besttrajectory = selectBestTrajectoryNode(root) - -# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) -# end - - - - diff --git a/src/mcts.jl b/src/mcts.jl index 8563692..0df0c1b 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -199,47 +199,18 @@ end # Signature """ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; - horizontalSample::Integer=3) - @sync for i in 1:horizontalSample - @spawn _expand(node, transition, transitionargs) + horizontalSample::Integer=3, multithread=false) + if multithread + @sync for i in 1:horizontalSample + @spawn _expand(node, transition, transitionargs) + end + else + for i in 1:horizontalSample + _expand(node, transition, transitionargs) + end end - - #CHANGE for testing - # for i in 1:horizontalSample - # _expand(node, transition, transitionargs) - # end end -# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; -# horizontalSample::Integer=3) -# nthSample = 0 -# listOfNewNodeId = [] -# while true -# nthSample += 1 -# if nthSample <= horizontalSample -# result = transition(node.state, transitionargs) -# newNodeKey::AbstractString = result[:newNodeKey] -# newstate::AbstractDict = result[:newstate] -# progressvalue::Integer = result[:progressvalue] - -# """ -# [] newNodeKey ∉ keys(node.children). -# New state may have semantic vector close enought to -# one of existing child state. Which can be assume that they are the same state -# semantically-wise i.e. De javu. This could be used to recall lessons for this -# similar situation to improve decisionMaker and evaluator. -# """ -# if newNodeKey ∉ keys(node.children) -# push!(listOfNewNodeId, newNodeKey) -# newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], -# newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) -# node.children[newNodeKey] = newNode -# end -# else -# return listOfNewNodeId -# end -# end -# end function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) result = transition(node.state, transitionargs) @@ -282,7 +253,7 @@ end # Signature """ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; - maxSimulationDepth::Integer=3, horizontalSample::Integer=3) + maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false) # )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}} simTrajectoryReward = 0.0 @@ -295,7 +266,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup break else _ = expand(node, transition, transitionargs; - horizontalSample=horizontalSample) + horizontalSample=horizontalSample, + multithread=multithread) node = selectChildNode(node) end end