From 6920be2334034ef64b2569fd2374f7827bbffa7a Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Sun, 2 Mar 2025 17:10:01 +0700 Subject: [PATCH] update --- src/interface.jl | 8 ++++++-- src/mcts.jl | 13 ++++++++----- src/type.jl | 3 +++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6fa8b84..27fe8d3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -52,7 +52,8 @@ function runMCTS( earlystop::Union{Function,Nothing}=nothing )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} - root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) + root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), + Dict{Symbol,Any}()) for nth in 1:maxiterations node = root @@ -67,7 +68,7 @@ function runMCTS( # do nothing then go directly to backpropagation. It means the end of this iteration backpropagate(node, node.reward) else - expand(node, transition, transitionargs; + _ = expand(node, transition, transitionargs; totalsample=totalsample) leafNode = selectChildNode(node) simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; @@ -82,6 +83,9 @@ function runMCTS( # end backpropagate(leafNode, simTrajectoryReward) + + # delete all child node, no need for child node that was created during simulation + leafNode.children = Dict{String,MCTSNode}() end # stop if the early stop condition is met diff --git a/src/mcts.jl b/src/mcts.jl index 4ca9672..f8d8400 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -234,6 +234,7 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; totalsample::Integer=3) nthSample = 0 + listOfNewNodeId = [] while true nthSample += 1 if nthSample <= totalsample @@ -250,12 +251,13 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; similar situation to improve decisionMaker and evaluator. """ if newNodeKey ∉ keys(node.children) + push!(listOfNewNodeId, newNodeKey) node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], - newstate[:isterminal], node, Dict{String, MCTSNode}()) + newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) end else - break + return listOfNewNodeId end end end @@ -286,6 +288,7 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup simTrajectoryReward = 0.0 terminalstate = nothing + # listOfSimulatedNodeId = [] for depth in 1:maxdepth simTrajectoryReward += node.reward @@ -293,10 +296,10 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup terminalstate = node.state break else - expand(node, transition, transitionargs; - totalsample=totalsample) + _ = expand(node, transition, transitionargs; + totalsample=totalsample) node = selectChildNode(node) - end + end end return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) diff --git a/src/type.jl b/src/type.jl index a32846b..cc1d7dc 100644 --- a/src/type.jl +++ b/src/type.jl @@ -2,6 +2,8 @@ module type export MCTSNode +using GeneralUtils + # ---------------------------------------------- 100 --------------------------------------------- # @@ -49,6 +51,7 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString} isterminal::Bool parent::Union{MCTSNode, Nothing} children::Dict{String, MCTSNode} + etc::Dict{Symbol, Any} # store anything end