This commit is contained in:
narawat lamaiin
2024-05-09 09:02:34 +07:00
parent 46bbb31699
commit 6ce4b90d26
3 changed files with 73 additions and 110 deletions

View File

@@ -51,7 +51,6 @@ mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2
state::T1
visits::Integer
stateevaluation::T2
statevalue::Number
reward::Number
isterminal::Bool
@@ -134,29 +133,12 @@ julia>
# Signature
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
nthSample = 0
while nthSample < n
try
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
if reward < 0
pprint(newstate[:thoughtHistory])
newstate.feedback = stateevaluation
end
if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
nthSample += 1
catch e
io = IOBuffer()
showerror(io, e)
@@ -166,6 +148,27 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
@warn "Error occurred: $errorMsg\n$st"
println("")
end
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
if reward < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:feedback] = reflector(a, newstate)
print("done reflection")
end
if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
nthSample += 1
end
end
@@ -192,7 +195,7 @@ julia>
# Signature
"""
function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
maxDepth::Int; n=3)::Number where {T<:agent}
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent}
simTrajectoryReward = 0.0
@@ -201,7 +204,7 @@ function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEs
if node.isterminal
break
else
expand(a, node, decisionMaker, progressValueEstimator, n=n)
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
node = selectChildNode(node)
end
end
@@ -456,7 +459,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, "N/A", 0, 0, false, nothing, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxIterations
node = root
@@ -468,10 +471,10 @@ function runMCTS(
# do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward)
else
expand(a, node, decisionMaker, progressValueEstimator, n=n)
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
leafNode = UCTselect(node, w)
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
maxDepth, n=n)
reflector; maxDepth=maxDepth, n=n)
backpropagate(leafNode, simTrajectoryReward)
end
end