diff --git a/src/interface.jl b/src/interface.jl index 4600a30..c04e59c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -85,10 +85,11 @@ function runMCTS( # do nothing then go directly to backpropagation. It means the end of this iteration backpropagate(node, node.reward) else + println(111) _ = expand(node, transition, transitionargs; horizontalSample=horizontalSampleExpansionPhase) - println(555) + println(666) @sync for (leafNodeKey, leafNode) in node.children @spawn simulateThenBackpropagate(leafNode, transition, transitionargs; @@ -96,6 +97,14 @@ function runMCTS( horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, saveSimulatedNode=saveSimulatedNode) end + + #CHANGE for testing + # for (leafNodeKey, leafNode) in node.children + # simulateThenBackpropagate(leafNode, transition, transitionargs; + # maxSimulationDepth=maxSimulationDepth, + # horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, + # saveSimulatedNode=saveSimulatedNode) + # end end # stop if the early stop condition is met diff --git a/src/mcts.jl b/src/mcts.jl index 926140d..8563692 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -198,68 +198,67 @@ end # Signature """ -# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple; -# totalsample::Integer=3) +function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; + horizontalSample::Integer=3) + @sync for i in 1:horizontalSample + @spawn _expand(node, transition, transitionargs) + end -# # not use Any[] because I want to preserve result order -# results = Vector{Any}(undef, totalsample) + #CHANGE for testing + # for i in 1:horizontalSample + # _expand(node, transition, transitionargs) + # end +end +# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; +# horizontalSample::Integer=3) -# @sync for i in 1:totalsample -# @spawn begin -# result = transition(deepcopy(node.state), deepcopy(transitionargs)) -# results[i] = result +# nthSample = 0 +# listOfNewNodeId = [] +# while true +# nthSample += 1 +# if nthSample <= horizontalSample +# result = transition(node.state, transitionargs) +# newNodeKey::AbstractString = result[:newNodeKey] +# newstate::AbstractDict = result[:newstate] +# progressvalue::Integer = result[:progressvalue] + +# """ +# [] newNodeKey ∉ keys(node.children). +# New state may have semantic vector close enought to +# one of existing child state. Which can be assume that they are the same state +# semantically-wise i.e. De javu. This could be used to recall lessons for this +# similar situation to improve decisionMaker and evaluator. +# """ +# if newNodeKey ∉ keys(node.children) +# push!(listOfNewNodeId, newNodeKey) +# newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], +# newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) +# node.children[newNodeKey] = newNode # end -# end - -# for result in results -# newNodeKey::AbstractString = result[:newNodeKey] -# newstate::AbstractDict = result[:newstate] -# progressvalue::Integer = result[:progressvalue] - -# """ -# [] newNodeKey ∉ keys(node.children). -# New state may have semantic vector close enought to -# one of existing child state. Which can be assume that they are the same state -# semantically-wise i.e. De javu. This could be used to recall lessons for this -# similar situation to improve decisionMaker and evaluator. -# """ -# if newNodeKey ∉ keys(node.children) -# node.children[newNodeKey] = -# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], -# newstate[:isterminal], node, Dict{String, MCTSNode}()) +# else +# return listOfNewNodeId # end # end # end -function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; - horizontalSample::Integer=3) - nthSample = 0 - listOfNewNodeId = [] - while true - nthSample += 1 - if nthSample <= horizontalSample - result = transition(node.state, transitionargs) - newNodeKey::AbstractString = result[:newNodeKey] - newstate::AbstractDict = result[:newstate] - progressvalue::Integer = result[:progressvalue] +function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) + result = transition(node.state, transitionargs) + newNodeKey::AbstractString = result[:newNodeKey] + newstate::AbstractDict = result[:newstate] + progressvalue::Integer = result[:progressvalue] - """ - [] newNodeKey ∉ keys(node.children). - New state may have semantic vector close enought to - one of existing child state. Which can be assume that they are the same state - semantically-wise i.e. De javu. This could be used to recall lessons for this - similar situation to improve decisionMaker and evaluator. - """ - if newNodeKey ∉ keys(node.children) - push!(listOfNewNodeId, newNodeKey) - newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], - newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) - node.children[newNodeKey] = newNode - end - else - return listOfNewNodeId + """ + [] newNodeKey ∉ keys(node.children). + New state may have semantic vector close enought to + one of existing child state. Which can be assume that they are the same state + semantically-wise i.e. De javu. This could be used to recall lessons for this + similar situation to improve decisionMaker and evaluator. + """ + if newNodeKey ∉ keys(node.children) + newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], + newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}()) + node.children[newNodeKey] = newNode end - end end