This commit is contained in:
narawat lamaiin
2025-03-02 17:10:01 +07:00
parent 84d73e742c
commit 6920be2334
3 changed files with 17 additions and 7 deletions

View File

@@ -52,7 +52,8 @@ function runMCTS(
earlystop::Union{Function,Nothing}=nothing earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any} )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}()) root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}())
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
@@ -67,7 +68,7 @@ function runMCTS(
# do nothing and go directly to backpropagation, which ends this iteration # do nothing and go directly to backpropagation, which ends this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
else else
expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
totalsample=totalsample) totalsample=totalsample)
leafNode = selectChildNode(node) leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs; simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
@@ -82,6 +83,9 @@ function runMCTS(
# end # end
backpropagate(leafNode, simTrajectoryReward) backpropagate(leafNode, simTrajectoryReward)
# delete all child nodes: children created during simulation are no longer needed
leafNode.children = Dict{String,MCTSNode}()
end end
# stop if the early stop condition is met # stop if the early stop condition is met

View File

@@ -234,6 +234,7 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3) totalsample::Integer=3)
nthSample = 0 nthSample = 0
listOfNewNodeId = []
while true while true
nthSample += 1 nthSample += 1
if nthSample <= totalsample if nthSample <= totalsample
@@ -250,12 +251,13 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
similar situation to improve decisionMaker and evaluator. similar situation to improve decisionMaker and evaluator.
""" """
if newNodeKey ∉ keys(node.children) if newNodeKey ∉ keys(node.children)
push!(listOfNewNodeId, newNodeKey)
node.children[newNodeKey] = node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}()) newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
end end
else else
break return listOfNewNodeId
end end
end end
end end
@@ -286,6 +288,7 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
# listOfSimulatedNodeId = []
for depth in 1:maxdepth for depth in 1:maxdepth
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
@@ -293,10 +296,10 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
terminalstate = node.state terminalstate = node.state
break break
else else
expand(node, transition, transitionargs; _ = expand(node, transition, transitionargs;
totalsample=totalsample) totalsample=totalsample)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)

View File

@@ -2,6 +2,8 @@ module type
export MCTSNode export MCTSNode
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
@@ -49,6 +51,7 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
isterminal::Bool isterminal::Bool
parent::Union{MCTSNode, Nothing} parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode} children::Dict{String, MCTSNode}
etc::Dict{Symbol, Any} # store anything
end end