This commit is contained in:
narawat lamaiin
2024-05-05 07:03:52 +07:00
parent 8907156522
commit 8f3132c7cd
2 changed files with 30 additions and 18 deletions

View File

@@ -122,20 +122,21 @@ julia>
# Signature
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
# sampling action from decisionMaker
for sample in 1:n
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
isterminal)
# add progressValueEstimator
progressRationale, progressValue = progressValueEstimator(a, newstate)
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
reward, isterminal, node, Dict{String, MCTSNode}())
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
end
end
@@ -158,11 +159,12 @@ julia>
# Signature
"""
function simulate(a, node::MCTSNode, decisionMaker, progressValueEstimator, max_depth::Int; n=3)
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
isterminal::Function, max_depth::Int; n=3)
for _ in 1:max_depth
node = selectChildNode(node)
expand(a, node, decisionMaker, progressValueEstimator, n=n)
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
# if isterminal (use for loop over node to look for childNode.reward != 0)
@@ -243,8 +245,9 @@ julia> thoughtDict = Dict(
# Signature
"""
function MCTStransition(a::T1, state::T2,
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
actionname = thoughtDict[:Action][:name]
actioninput = thoughtDict[:Action][:input]
@@ -401,11 +404,12 @@ function runMCTS(
node = UCTselect(node, w)
end
expand(a, node, decisionMaker, progressValueEstimator, n=n)
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
# from paper, just start simulation at this node. Not the node that newly expanded
startsim_node = node
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator, maxDepth, n=n)
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
isterminal, maxDepth, n=n)
backpropagate(leaf_node, reward)
end