This commit is contained in:
narawat lamaiin
2024-06-01 00:37:20 +07:00
parent 452262d3d6
commit 74a4190440
3 changed files with 71 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
expand, mctstransition
expand, simulate, mctstransition
using ..type
@@ -250,7 +250,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
while true
nthSample += 1
if nthSample <= totalsample
thoughtDict = decisionMaker(a, node.state)
thoughtDict = decisionMaker(workDict, node.state)
println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict)
@@ -261,7 +261,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
newstate[:lesson] = reflector(workDict, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
@@ -288,6 +288,52 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
end
""" Simulate interactions between agent and environment
# Arguments
- `a::T`
one of YiemAgent's agent
- `node::MCTSNode`
node that will be a simulation starting point.
- `decisionMaker::Function`
function that receive state return Thought and Action
# Return
- `simTrajectoryReward::Number`
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
# Signature
"""
function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
simTrajectoryReward = 0.0
terminalstate = nothing
for depth in 1:maxDepth
simTrajectoryReward += node.reward
if node.isterminal
terminalstate = node.state
break
else
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
node = selectChildNode(node)
end
end
return (simTrajectoryReward, terminalstate)
end
""" Get a new state
# Arguments
@@ -329,7 +375,7 @@ julia> thoughtDict = Dict(
"""
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
error("--> mctstransition")
# actionname = thoughtDict[:action][:name]
# actioninput = thoughtDict[:action][:input]