update
This commit is contained in:
@@ -52,7 +52,8 @@ function runMCTS(
|
|||||||
earlystop::Union{Function,Nothing}=nothing
|
earlystop::Union{Function,Nothing}=nothing
|
||||||
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||||
|
|
||||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||||
|
Dict{Symbol,Any}())
|
||||||
|
|
||||||
for nth in 1:maxiterations
|
for nth in 1:maxiterations
|
||||||
node = root
|
node = root
|
||||||
@@ -67,7 +68,7 @@ function runMCTS(
|
|||||||
# do nothing then go directly to backpropagation. It means the end of this iteration
|
# do nothing then go directly to backpropagation. It means the end of this iteration
|
||||||
backpropagate(node, node.reward)
|
backpropagate(node, node.reward)
|
||||||
else
|
else
|
||||||
expand(node, transition, transitionargs;
|
_ = expand(node, transition, transitionargs;
|
||||||
totalsample=totalsample)
|
totalsample=totalsample)
|
||||||
leafNode = selectChildNode(node)
|
leafNode = selectChildNode(node)
|
||||||
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||||
@@ -82,6 +83,9 @@ function runMCTS(
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
backpropagate(leafNode, simTrajectoryReward)
|
backpropagate(leafNode, simTrajectoryReward)
|
||||||
|
|
||||||
|
# delete all child node, no need for child node that was created during simulation
|
||||||
|
leafNode.children = Dict{String,MCTSNode}()
|
||||||
end
|
end
|
||||||
|
|
||||||
# stop if the early stop condition is met
|
# stop if the early stop condition is met
|
||||||
|
|||||||
11
src/mcts.jl
11
src/mcts.jl
@@ -234,6 +234,7 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
|||||||
totalsample::Integer=3)
|
totalsample::Integer=3)
|
||||||
|
|
||||||
nthSample = 0
|
nthSample = 0
|
||||||
|
listOfNewNodeId = []
|
||||||
while true
|
while true
|
||||||
nthSample += 1
|
nthSample += 1
|
||||||
if nthSample <= totalsample
|
if nthSample <= totalsample
|
||||||
@@ -250,12 +251,13 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
|||||||
similar situation to improve decisionMaker and evaluator.
|
similar situation to improve decisionMaker and evaluator.
|
||||||
"""
|
"""
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
|
push!(listOfNewNodeId, newNodeKey)
|
||||||
node.children[newNodeKey] =
|
node.children[newNodeKey] =
|
||||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
break
|
return listOfNewNodeId
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -286,6 +288,7 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
|||||||
|
|
||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
terminalstate = nothing
|
terminalstate = nothing
|
||||||
|
# listOfSimulatedNodeId = []
|
||||||
|
|
||||||
for depth in 1:maxdepth
|
for depth in 1:maxdepth
|
||||||
simTrajectoryReward += node.reward
|
simTrajectoryReward += node.reward
|
||||||
@@ -293,8 +296,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
|||||||
terminalstate = node.state
|
terminalstate = node.state
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
expand(node, transition, transitionargs;
|
_ = expand(node, transition, transitionargs;
|
||||||
totalsample=totalsample)
|
totalsample=totalsample)
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ module type
|
|||||||
|
|
||||||
export MCTSNode
|
export MCTSNode
|
||||||
|
|
||||||
|
using GeneralUtils
|
||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +51,7 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
|||||||
isterminal::Bool
|
isterminal::Bool
|
||||||
parent::Union{MCTSNode, Nothing}
|
parent::Union{MCTSNode, Nothing}
|
||||||
children::Dict{String, MCTSNode}
|
children::Dict{String, MCTSNode}
|
||||||
|
etc::Dict{Symbol, Any} # store anything
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user