update
This commit is contained in:
93
src/mcts.jl
93
src/mcts.jl
@@ -51,7 +51,7 @@ struct MCTSNode{T<:AbstractDict}
|
||||
nodekey::String
|
||||
state::T
|
||||
visits::Integer
|
||||
progressValue::Number
|
||||
statevalue::Number
|
||||
reward::Number
|
||||
isterminal::Bool
|
||||
parent::Union{MCTSNode, Nothing}
|
||||
@@ -83,7 +83,7 @@ function UCTselect(node::MCTSNode, w::Float64)
|
||||
selectedNode = nothing
|
||||
|
||||
for (childState, childNode) in node.children
|
||||
uctValue = childNode.stateValue +
|
||||
uctValue = childNode.statevalue +
|
||||
w * sqrt(log(node.visits) / childNode.visits)
|
||||
if uctValue > max_uct
|
||||
max_uct = uctValue
|
||||
@@ -132,10 +132,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
isterminal)
|
||||
|
||||
# add progressValueEstimator
|
||||
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
||||
progressRationale, statevalue = progressValueEstimator(a, newstate)
|
||||
statevalue += reward
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
end
|
||||
@@ -144,6 +145,8 @@ end
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node that will be a simulation starting point.
|
||||
|
||||
# Return
|
||||
|
||||
@@ -160,19 +163,40 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
isterminal::Function, max_depth::Int; n=3)
|
||||
isterminal::Function, max_depth::Int; n=3)::Number
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
|
||||
for _ in 1:max_depth
|
||||
if node.isterminal
|
||||
break
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
end
|
||||
node = selectChildNode(node)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
# if isterminal (use for loop over node to look for childNode.reward != 0)
|
||||
|
||||
|
||||
simTrajectoryReward += node.reward
|
||||
end
|
||||
error("--> simulate")
|
||||
return total_reward
|
||||
|
||||
return simTrajectoryReward
|
||||
end
|
||||
# function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
# isterminal::Function, max_depth::Int; n=3)::Number
|
||||
|
||||
# simTrajectoryReward = 0.0
|
||||
|
||||
# for _ in 1:max_depth
|
||||
# node = selectChildNode(node)
|
||||
# simTrajectoryReward += node.reward
|
||||
|
||||
# if node.isterminal
|
||||
# break
|
||||
# else
|
||||
# expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
# end
|
||||
# end
|
||||
|
||||
# return simTrajectoryReward
|
||||
# end
|
||||
|
||||
"""
|
||||
|
||||
@@ -187,20 +211,32 @@ julia>
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [] implement the function
|
||||
- [WORKING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, reward::Float64)
|
||||
node.visits += 1
|
||||
|
||||
# [] there is no total_reward in the paper, buy they use stateValue
|
||||
node.total_reward += reward
|
||||
if !isempty(node.children)
|
||||
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||
backpropagate(node.children[best_child], -reward)
|
||||
end
|
||||
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||
|
||||
# Backpropagate the result to the parent node recursively
|
||||
if !isroot(node)
|
||||
simTrajectoryReward *= discountRewardCoeff
|
||||
backpropagate(node.parent, simTrajectoryReward)
|
||||
end
|
||||
end
|
||||
# function backpropagate(node::MCTSNode, reward::Float64)
|
||||
# node.visits += 1
|
||||
|
||||
# # [] there is no total_reward in the paper, buy they use stateValue
|
||||
# node.total_reward += reward
|
||||
# if !isempty(node.children)
|
||||
# best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||
# backpropagate(node.children[best_child], -reward)
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
""" Get a new state
|
||||
|
||||
@@ -310,7 +346,7 @@ true
|
||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
|
||||
""" Select child node based on the highest progressValue
|
||||
""" Select child node based on the highest statevalue
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
@@ -333,8 +369,8 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
|
||||
# loop thought node children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
if childNode.progressValue > highestProgressValue
|
||||
highestProgressValue = childNode.progressValue
|
||||
if childNode.statevalue > highestProgressValue
|
||||
highestProgressValue = childNode.statevalue
|
||||
nodekey = childNode.nodekey
|
||||
end
|
||||
end
|
||||
@@ -402,11 +438,10 @@ function runMCTS(
|
||||
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||
startsim_node = node
|
||||
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
|
||||
leaf_node = selectChildNode(node)
|
||||
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
backpropagate(leaf_node, reward)
|
||||
backpropagate(leaf_node, simTrajectoryReward)
|
||||
end
|
||||
|
||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||
|
||||
Reference in New Issue
Block a user