update
This commit is contained in:
@@ -2,7 +2,7 @@ module interface
|
|||||||
|
|
||||||
export runMCTS
|
export runMCTS
|
||||||
|
|
||||||
using ..type, ..mcts
|
using ..type, ..mcts, ..util
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function selectBestNextState(node::MCTSNode)::MCTSNode
|
function selectBestNextState(node::MCTSNode)::MCTSNode
|
||||||
highestProgressValue = 0
|
highestProgressValue = -1
|
||||||
nodekey = nothing
|
nodekey = nothing
|
||||||
|
|
||||||
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
||||||
@@ -251,7 +251,10 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
|||||||
while true
|
while true
|
||||||
nthSample += 1
|
nthSample += 1
|
||||||
if nthSample <= totalsample
|
if nthSample <= totalsample
|
||||||
newNodeKey, newstate, progressvalue = transition(node.state, transitionargs)
|
result = transition(node.state, transitionargs)
|
||||||
|
newNodeKey::AbstractString = result[:newNodeKey]
|
||||||
|
newstate::AbstractDict = result[:newstate]
|
||||||
|
progressvalue::Integer = result[:progressvalue]
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] =
|
node.children[newNodeKey] =
|
||||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
|
|||||||
Reference in New Issue
Block a user