This commit is contained in:
narawat lamaiin
2025-03-07 13:33:38 +07:00
parent 6920be2334
commit 9add88b145
10 changed files with 28 additions and 1401 deletions

View File

@@ -2,6 +2,7 @@ module interface
export runMCTS
using Base.Threads
using ..type, ..mcts, ..util
@@ -45,7 +46,8 @@ function runMCTS(
transition::Function,
transitionargs::NamedTuple,
;
totalsample::Integer=3,
horizontalSampleExpansionPhase::Integer=3,
horizontalSampleSimulationPhase::Integer=3,
maxdepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
@@ -69,10 +71,20 @@ function runMCTS(
backpropagate(node, node.reward)
else
_ = expand(node, transition, transitionargs;
totalsample=totalsample)
horizontalSample=horizontalSampleExpansionPhase)
#[WORKING] make simulation parallel, leafNodes must be newly expanded nodes
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample)
# outputch = Channel(8)
#[WORKING] test whether multiple spawn retain result leafNode's child node
@spawn simulate(outputch, leafNode, transition, transitionargs;
maxdepth=maxdepth, horizontalSample=horizontalSampleSimulationPhase)
# if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end
@@ -82,10 +94,9 @@ function runMCTS(
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
# result = take!(outputch)
# delete all child node, no need for child node that was created during simulation
leafNode.children = Dict{String,MCTSNode}()
backpropagate(leafNode, simTrajectoryReward)
end
# stop if the early stop condition is met