update
This commit is contained in:
13
src/mcts.jl
13
src/mcts.jl
@@ -244,16 +244,13 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||
reflector::Function, transition::Function; totalsample::Integer=3
|
||||
) where {T1<:AbstractDict}
|
||||
function expand(node::MCTSNode,transition::Function, args...; totalsample::Integer=3)
|
||||
|
||||
nthSample = 0
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= totalsample
|
||||
newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker,
|
||||
evaluator, reflector)
|
||||
newNodeKey, newstate, progressvalue = transition(node.state, args...)
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] =
|
||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
@@ -289,8 +286,8 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||
reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3
|
||||
function simulate(node::MCTSNode, transition::Function, args...;
|
||||
maxDepth::Integer=3, totalsample::Integer=3
|
||||
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
@@ -302,7 +299,7 @@ function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator:
|
||||
terminalstate = node.state
|
||||
break
|
||||
else
|
||||
expand(config, node, decisionMaker, evaluator, reflector, transition;
|
||||
expand(node, transition, args...;
|
||||
totalsample=totalsample)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user