This commit is contained in:
narawat lamaiin
2024-06-01 15:03:15 +07:00
parent de51dfd69d
commit 3269b27237
10 changed files with 1394 additions and 15 deletions

View File

@@ -47,12 +47,9 @@ julia>
# Signature
"""
function runMCTS(
config::T1,
initialState,
decisionMaker::Function,
evaluator::Function,
reflector::Function,
transition::Function,
args...,
;
totalsample::Integer=3,
maxDepth::Integer=3,
@@ -74,11 +71,11 @@ function runMCTS(
# do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward)
else
expand(config, node, decisionMaker, evaluator, reflector, transition;
expand(node, transition, args...;
totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(config, leafNode, decisionMaker, evaluator,
reflector, transition; maxDepth=maxDepth, totalsample=totalsample)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, args...;
maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing #XXX not sure why I need this
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end

View File

@@ -244,16 +244,13 @@ julia>
# Signature
"""
function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; totalsample::Integer=3
) where {T1<:AbstractDict}
function expand(node::MCTSNode,transition::Function, args...; totalsample::Integer=3)
nthSample = 0
while true
nthSample += 1
if nthSample <= totalsample
newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker,
evaluator, reflector)
newNodeKey, newstate, progressvalue = transition(node.state, args...)
if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
@@ -289,8 +286,8 @@ julia>
# Signature
"""
function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3
function simulate(node::MCTSNode, transition::Function, args...;
maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
simTrajectoryReward = 0.0
@@ -302,7 +299,7 @@ function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator:
terminalstate = node.state
break
else
expand(config, node, decisionMaker, evaluator, reflector, transition;
expand(node, transition, args...;
totalsample=totalsample)
node = selectChildNode(node)
end