This commit is contained in:
narawat lamaiin
2024-06-20 16:54:57 +07:00
parent e9268ce500
commit 202981d287
2 changed files with 12 additions and 12 deletions

View File

@@ -47,17 +47,17 @@ julia>
# Signature # Signature
""" """
function runMCTS( function runMCTS(
initialState, initialstate,
transition::Function, transition::Function,
args..., transitionargs::NamedTuple,
; ;
totalsample::Integer=3, totalsample::Integer=3,
maxDepth::Integer=3, maxdepth::Integer=3,
maxiterations::Integer=10, maxiterations::Integer=10,
explorationweight::Number=1.0, explorationweight::Number=1.0,
) )
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
@@ -71,11 +71,11 @@ function runMCTS(
# do nothing then go directly to backpropagation # do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward) backpropagate(leafNode, node.reward)
else else
expand(node, transition, args...; expand(node, transition, transitionargs;
totalsample=totalsample) totalsample=totalsample)
leafNode = selectChildNode(node) leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, args...; simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxDepth=maxDepth, totalsample=totalsample) maxdepth=maxdepth, totalsample=totalsample)
if terminalstate !== nothing #XXX not sure why I need this if terminalstate !== nothing #XXX not sure why I need this
terminalstate[:totalTrajectoryReward] = simTrajectoryReward terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end end

View File

@@ -244,13 +244,13 @@ julia>
# Signature # Signature
""" """
function expand(node::MCTSNode,transition::Function, args...; totalsample::Integer=3) function expand(node::MCTSNode,transition::Function, args::NamedTuple; totalsample::Integer=3)
nthSample = 0 nthSample = 0
while true while true
nthSample += 1 nthSample += 1
if nthSample <= totalsample if nthSample <= totalsample
newNodeKey, newstate, progressvalue = transition(node.state, args...) newNodeKey, newstate, progressvalue = transition(node.state, args)
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
@@ -287,19 +287,19 @@ julia>
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, args...; function simulate(node::MCTSNode, transition::Function, args...;
maxDepth::Integer=3, totalsample::Integer=3 maxdepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
for depth in 1:maxDepth for depth in 1:maxdepth
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
if node.isterminal if node.isterminal
terminalstate = node.state terminalstate = node.state
break break
else else
expand(node, transition, args...; expand(node, transition, args;
totalsample=totalsample) totalsample=totalsample)
node = selectChildNode(node) node = selectChildNode(node)
end end