This commit is contained in:
narawat lamaiin
2024-05-10 11:59:38 +07:00
parent c7fd7bc40d
commit a4ba292fad
4 changed files with 107 additions and 44 deletions

View File

@@ -136,31 +136,84 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
nthSample = 0
while nthSample < n
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
if reward < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
a.lesson[:lesson_1] = deepcopy(newstate)
print("---> reflector()")
end
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
while true
nthSample += 1
if nthSample <= n
println("---> expand() sample $nthSample")
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate)
if reward < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
else
break
end
end
end
# function expand(a::T1, node::MCTSNode, decisionMaker::Function,
# progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
# nthSample = 0
# while nthSample <= n
# nthSample += 1
# println("---> expand() sample $nthSample")
# thoughtDict = decisionMaker(a, node.state)
# newNodeKey, newstate, reward, isterminalstate =
# MCTStransition(a, node.state, thoughtDict)
# # add progressValueEstimator
# stateevaluation, statevalue = progressValueEstimator(a, newstate)
# if reward < 0
# pprint(newstate[:thoughtHistory])
# newstate[:evaluation] = stateevaluation
# newstate[:lesson] = reflector(a, newstate)
# # store new lesson for later use
# lessonDict = copy(JSON3.read("lesson.json"))
# latestLessonKey, latestLessonIndice =
# GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "lesson")
# nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
# newLessonKey = Symbol("lesson_$(nextIndice)")
# lessonDict[newLessonKey] = newstate
# open("lesson.json", "w") do io
# JSON3.pretty(io, lessonDict)
# end
# print("---> reflector()")
# end
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
# reward, isterminalstate, node, Dict{String, MCTSNode}())
# end
# end
# end
@@ -299,8 +352,8 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3
newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
latestObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
newObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal