update
This commit is contained in:
51
src/mcts.jl
51
src/mcts.jl
@@ -51,7 +51,6 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||
nodekey::T2
|
||||
state::T1
|
||||
visits::Integer
|
||||
stateevaluation::T2
|
||||
statevalue::Number
|
||||
reward::Number
|
||||
isterminal::Bool
|
||||
@@ -134,29 +133,12 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
|
||||
progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
|
||||
|
||||
nthSample = 0
|
||||
while nthSample < n
|
||||
try
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, reward, isterminalstate =
|
||||
MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if reward < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
newstate.feedback = stateevaluation
|
||||
end
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
nthSample += 1
|
||||
catch e
|
||||
io = IOBuffer()
|
||||
showerror(io, e)
|
||||
@@ -166,6 +148,27 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
@warn "Error occurred: $errorMsg\n$st"
|
||||
println("")
|
||||
end
|
||||
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, reward, isterminalstate =
|
||||
MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if reward < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
newstate[:evaluation] = stateevaluation
|
||||
newstate[:feedback] = reflector(a, newstate)
|
||||
print("done reflection")
|
||||
end
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
nthSample += 1
|
||||
end
|
||||
end
|
||||
|
||||
@@ -192,7 +195,7 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
maxDepth::Int; n=3)::Number where {T<:agent}
|
||||
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
|
||||
@@ -201,7 +204,7 @@ function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEs
|
||||
if node.isterminal
|
||||
break
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
@@ -456,7 +459,7 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
root = MCTSNode("root", initialState, 0, "N/A", 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
|
||||
for nth in 1:maxIterations
|
||||
node = root
|
||||
@@ -468,10 +471,10 @@ function runMCTS(
|
||||
# do nothing then go directly to backpropagation
|
||||
backpropagate(leafNode, node.reward)
|
||||
else
|
||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
||||
leafNode = UCTselect(node, w)
|
||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
||||
maxDepth, n=n)
|
||||
reflector; maxDepth=maxDepth, n=n)
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user