update
This commit is contained in:
19
src/mcts.jl
19
src/mcts.jl
@@ -231,13 +231,13 @@ end
|
||||
# end
|
||||
# end
|
||||
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
totalsample::Integer=3)
|
||||
horizontalSample::Integer=3)
|
||||
|
||||
nthSample = 0
|
||||
listOfNewNodeId = []
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= totalsample
|
||||
if nthSample <= horizontalSample
|
||||
result = transition(node.state, transitionargs)
|
||||
newNodeKey::AbstractString = result[:newNodeKey]
|
||||
newstate::AbstractDict = result[:newstate]
|
||||
@@ -252,9 +252,9 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
"""
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
push!(listOfNewNodeId, newNodeKey)
|
||||
node.children[newNodeKey] =
|
||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
newNode = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
newstate[:isterminal], node, Dict{String, MCTSNode}(), Dict{Symbol, Any}())
|
||||
node.children[newNodeKey] = newNode
|
||||
end
|
||||
else
|
||||
return listOfNewNodeId
|
||||
@@ -274,7 +274,7 @@ end
|
||||
Arguments for everything the user will use within transition().
|
||||
- `maxdepth::Integer`
|
||||
maximum depth level MCTS goes vertically.
|
||||
- totalsample::Integer
|
||||
- `horizontalSample::Integer`
|
||||
Number of times to sample a transition from the current node (i.e. the maximum number of new child nodes to expand horizontally).
|
||||
|
||||
# Return
|
||||
@@ -282,8 +282,8 @@ end
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxdepth::Integer=3, totalsample::Integer=3
|
||||
function simulate(outputchannel::Channel, node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxdepth::Integer=3, horizontalSample::Integer=3
|
||||
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
@@ -297,12 +297,13 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
||||
break
|
||||
else
|
||||
_ = expand(node, transition, transitionargs;
|
||||
totalsample=totalsample)
|
||||
horizontalSample=horizontalSample)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
|
||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
||||
put!(outputchannel, (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate))
|
||||
# return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
||||
end
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user