update
This commit is contained in:
@@ -66,7 +66,8 @@ function runMCTS(
|
||||
maxiterations::Integer=10,
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing,
|
||||
saveSimulatedNode::Bool=false) where {T<:Any}
|
||||
saveSimulatedNode::Bool=false,
|
||||
multithread=false) where {T<:Any}
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||
@@ -85,26 +86,26 @@ function runMCTS(
|
||||
# do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
backpropagate(node, node.reward)
|
||||
else
|
||||
println(111)
|
||||
_ = expand(node, transition, transitionargs;
|
||||
horizontalSample=horizontalSampleExpansionPhase)
|
||||
|
||||
println(666)
|
||||
|
||||
@sync for (leafNodeKey, leafNode) in node.children
|
||||
@spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
saveSimulatedNode=saveSimulatedNode)
|
||||
horizontalSample=horizontalSampleExpansionPhase,
|
||||
multithread=multithread)
|
||||
if multithread
|
||||
@sync for (leafNodeKey, leafNode) in node.children
|
||||
@spawn simulateThenBackpropagate(leafNode, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
saveSimulatedNode=saveSimulatedNode,
|
||||
multithread=multithread)
|
||||
end
|
||||
else
|
||||
for (leafNodeKey, leafNode) in node.children
|
||||
simulateThenBackpropagate(leafNode, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
saveSimulatedNode=saveSimulatedNode,
|
||||
multithread=multithread)
|
||||
end
|
||||
end
|
||||
|
||||
#CHANGE for testing
|
||||
# for (leafNodeKey, leafNode) in node.children
|
||||
# simulateThenBackpropagate(leafNode, transition, transitionargs;
|
||||
# maxSimulationDepth=maxSimulationDepth,
|
||||
# horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
# saveSimulatedNode=saveSimulatedNode)
|
||||
# end
|
||||
end
|
||||
|
||||
# stop if the early stop condition is met
|
||||
@@ -123,10 +124,12 @@ end
|
||||
|
||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||
saveSimulatedNode::Bool=false)
|
||||
saveSimulatedNode::Bool=false,
|
||||
multithread=false)
|
||||
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSample=horizontalSampleSimulationPhase)
|
||||
horizontalSample=horizontalSampleSimulationPhase,
|
||||
multithread=multithread)
|
||||
backpropagate(node, simTrajectoryReward)
|
||||
|
||||
# check if the user wants to keep the simulated node
|
||||
@@ -137,58 +140,6 @@ end
|
||||
|
||||
|
||||
|
||||
# function runMCTS(
|
||||
# initialstate::T,
|
||||
# transition::Function,
|
||||
# transitionargs::NamedTuple,
|
||||
# ;
|
||||
# totalsample::Integer=3,
|
||||
# maxdepth::Integer=3,
|
||||
# maxiterations::Integer=10,
|
||||
# explorationweight::Number=1.0,
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
# for nth in 1:maxiterations
|
||||
# node = root
|
||||
# node.visits += 1
|
||||
|
||||
# while !isleaf(node)
|
||||
# node = UCTselect(node, explorationweight)
|
||||
# end
|
||||
# if node.isterminal
|
||||
# # MCTS arrives at a leaf node that is also a terminal state,
|
||||
# # do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
# backpropagate(leafNode, node.reward)
|
||||
# else
|
||||
# expand(node, transition, transitionargs;
|
||||
# totalsample=totalsample)
|
||||
# leafNode = selectChildNode(node)
|
||||
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
# maxdepth=maxdepth, totalsample=totalsample)
|
||||
# # if terminalstate !== nothing #XXX not sure why I need this
|
||||
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# # end
|
||||
|
||||
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# # open("trajectory.json", "w") do io
|
||||
# # JSON3.pretty(io, terminalstate)
|
||||
# # end
|
||||
|
||||
# backpropagate(leafNode, simTrajectoryReward)
|
||||
# end
|
||||
# end
|
||||
|
||||
# bestNextState = selectBestNextNode(root)
|
||||
# besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
52
src/mcts.jl
52
src/mcts.jl
@@ -199,47 +199,18 @@ end
|
||||
# Signature
|
||||
"""
|
||||
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
horizontalSample::Integer=3)
|
||||
@sync for i in 1:horizontalSample
|
||||
@spawn _expand(node, transition, transitionargs)
|
||||
horizontalSample::Integer=3, multithread=false)
|
||||
if multithread
|
||||
@sync for i in 1:horizontalSample
|
||||
@spawn _expand(node, transition, transitionargs)
|
||||
end
|
||||
else
|
||||
for i in 1:horizontalSample
|
||||
_expand(node, transition, transitionargs)
|
||||
end
|
||||
end
|
||||
|
||||
#CHANGE for testing
|
||||
# for i in 1:horizontalSample
|
||||
# _expand(node, transition, transitionargs)
|
||||
# end
|
||||
end
|
||||
# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
# horizontalSample::Integer=3)
|
||||
|
||||
# nthSample = 0
|
||||
# listOfNewNodeId = []
|
||||
# while true
|
||||
# nthSample += 1
|
||||
# if nthSample <= horizontalSample
|
||||
# result = transition(node.state, transitionargs)
|
||||
# newNodeKey::AbstractString = result[:newNodeKey]
|
||||
# newstate::AbstractDict = result[:newstate]
|
||||
# progressvalue::Integer = result[:progressvalue]
|
||||
|
||||
# """
|
||||
# [] newNodeKey ∉ keys(node.children).
|
||||
# New state may have a semantic vector close enough to
|
||||
# one of the existing child states, in which case they can be assumed to be the same state
|
||||
# semantically, i.e. déjà vu. This could be used to recall lessons for this
|
||||
# similar situation to improve decisionMaker and evaluator.
|
||||
# """
|
||||
# if newNodeKey ∉ keys(node.children)
|
||||
# push!(listOfNewNodeId, newNodeKey)
|
||||
# newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
# newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
|
||||
# node.children[newNodeKey] = newNode
|
||||
# end
|
||||
# else
|
||||
# return listOfNewNodeId
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
|
||||
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
||||
result = transition(node.state, transitionargs)
|
||||
@@ -282,7 +253,7 @@ end
|
||||
# Signature
|
||||
"""
|
||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxSimulationDepth::Integer=3, horizontalSample::Integer=3)
|
||||
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false)
|
||||
# )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
@@ -295,7 +266,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
||||
break
|
||||
else
|
||||
_ = expand(node, transition, transitionargs;
|
||||
horizontalSample=horizontalSample)
|
||||
horizontalSample=horizontalSample,
|
||||
multithread=multithread)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user