update
This commit is contained in:
@@ -2,6 +2,7 @@ module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using Base.Threads
|
||||
using ..type, ..mcts, ..util
|
||||
|
||||
|
||||
@@ -45,7 +46,8 @@ function runMCTS(
|
||||
transition::Function,
|
||||
transitionargs::NamedTuple,
|
||||
;
|
||||
totalsample::Integer=3,
|
||||
horizontalSampleExpansionPhase::Integer=3,
|
||||
horizontalSampleSimulationPhase::Integer=3,
|
||||
maxdepth::Integer=3,
|
||||
maxiterations::Integer=10,
|
||||
explorationweight::Number=1.0,
|
||||
@@ -69,10 +71,20 @@ function runMCTS(
|
||||
backpropagate(node, node.reward)
|
||||
else
|
||||
_ = expand(node, transition, transitionargs;
|
||||
totalsample=totalsample)
|
||||
horizontalSample=horizontalSampleExpansionPhase)
|
||||
#[WORKING] make simulation parallel, leafNodes must be newly expanded nodes
|
||||
|
||||
leafNode = selectChildNode(node)
|
||||
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
maxdepth=maxdepth, totalsample=totalsample)
|
||||
|
||||
|
||||
|
||||
|
||||
# outputch = Channel(8)
|
||||
|
||||
#[WORKING] test whether multiple spawn retain result leafNode's child node
|
||||
|
||||
@spawn simulate(outputch, leafNode, transition, transitionargs;
|
||||
maxdepth=maxdepth, horizontalSample=horizontalSampleSimulationPhase)
|
||||
# if terminalstate !== nothing #XXX not sure why I need this
|
||||
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# end
|
||||
@@ -82,10 +94,9 @@ function runMCTS(
|
||||
# JSON3.pretty(io, terminalstate)
|
||||
# end
|
||||
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
# result = take!(outputch)
|
||||
|
||||
# delete all child node, no need for child node that was created during simulation
|
||||
leafNode.children = Dict{String,MCTSNode}()
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
|
||||
# stop if the early stop condition is met
|
||||
|
||||
Reference in New Issue
Block a user