This commit is contained in:
narawat lamaiin
2025-03-07 13:33:38 +07:00
parent 6920be2334
commit 9add88b145
10 changed files with 28 additions and 1401 deletions

View File

@@ -231,13 +231,13 @@ end
# end
# end
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3)
horizontalSample::Integer=3)
nthSample = 0
listOfNewNodeId = []
while true
nthSample += 1
if nthSample <= totalsample
if nthSample <= horizontalSample
result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate]
@@ -252,9 +252,9 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
"""
if newNodeKey ∉ keys(node.children)
push!(listOfNewNodeId, newNodeKey)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
node.children[newNodeKey] = newNode
end
else
return listOfNewNodeId
@@ -274,7 +274,7 @@ end
Arguments for everything the user will use within transition().
- `maxdepth::Integer`
Maximum depth (number of levels) that MCTS descends vertically during a simulation.
- totalsample::Integer
- horizontalSample::Integer
Number of samples to draw from the current node (i.e. how many new child nodes to expand horizontally).
# Return
@@ -282,8 +282,8 @@ end
# Signature
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3
function simulate(outputchannel::Channel, node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, horizontalSample::Integer=3
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0
@@ -297,12 +297,13 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
break
else
_ = expand(node, transition, transitionargs;
totalsample=totalsample)
horizontalSample=horizontalSample)
node = selectChildNode(node)
end
end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
put!(outputchannel, (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate))
# return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
end