update
This commit is contained in:
@@ -47,17 +47,17 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function runMCTS(
|
function runMCTS(
|
||||||
initialState,
|
initialstate,
|
||||||
transition::Function,
|
transition::Function,
|
||||||
args...,
|
transitionargs::NamedTuple,
|
||||||
;
|
;
|
||||||
totalsample::Integer=3,
|
totalsample::Integer=3,
|
||||||
maxDepth::Integer=3,
|
maxdepth::Integer=3,
|
||||||
maxiterations::Integer=10,
|
maxiterations::Integer=10,
|
||||||
explorationweight::Number=1.0,
|
explorationweight::Number=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
for nth in 1:maxiterations
|
for nth in 1:maxiterations
|
||||||
node = root
|
node = root
|
||||||
@@ -71,11 +71,11 @@ function runMCTS(
|
|||||||
# do nothing then go directly to backpropagation
|
# do nothing then go directly to backpropagation
|
||||||
backpropagate(leafNode, node.reward)
|
backpropagate(leafNode, node.reward)
|
||||||
else
|
else
|
||||||
expand(node, transition, args...;
|
expand(node, transition, transitionargs;
|
||||||
totalsample=totalsample)
|
totalsample=totalsample)
|
||||||
leafNode = selectChildNode(node)
|
leafNode = selectChildNode(node)
|
||||||
simTrajectoryReward, terminalstate = simulate(leafNode, transition, args...;
|
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||||
maxDepth=maxDepth, totalsample=totalsample)
|
maxdepth=maxdepth, totalsample=totalsample)
|
||||||
if terminalstate !== nothing #XXX not sure why I need this
|
if terminalstate !== nothing #XXX not sure why I need this
|
||||||
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||||
end
|
end
|
||||||
|
|||||||
10
src/mcts.jl
10
src/mcts.jl
@@ -244,13 +244,13 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(node::MCTSNode,transition::Function, args...; totalsample::Integer=3)
|
function expand(node::MCTSNode,transition::Function, args::NamedTuple; totalsample::Integer=3)
|
||||||
|
|
||||||
nthSample = 0
|
nthSample = 0
|
||||||
while true
|
while true
|
||||||
nthSample += 1
|
nthSample += 1
|
||||||
if nthSample <= totalsample
|
if nthSample <= totalsample
|
||||||
newNodeKey, newstate, progressvalue = transition(node.state, args...)
|
newNodeKey, newstate, progressvalue = transition(node.state, args)
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] =
|
node.children[newNodeKey] =
|
||||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
@@ -287,19 +287,19 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(node::MCTSNode, transition::Function, args...;
|
function simulate(node::MCTSNode, transition::Function, args...;
|
||||||
maxDepth::Integer=3, totalsample::Integer=3
|
maxdepth::Integer=3, totalsample::Integer=3
|
||||||
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
|
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
|
||||||
|
|
||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
terminalstate = nothing
|
terminalstate = nothing
|
||||||
|
|
||||||
for depth in 1:maxDepth
|
for depth in 1:maxdepth
|
||||||
simTrajectoryReward += node.reward
|
simTrajectoryReward += node.reward
|
||||||
if node.isterminal
|
if node.isterminal
|
||||||
terminalstate = node.state
|
terminalstate = node.state
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
expand(node, transition, args...;
|
expand(node, transition, args;
|
||||||
totalsample=totalsample)
|
totalsample=totalsample)
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user