update
This commit is contained in:
856
src/mcts.jl
856
src/mcts.jl
@@ -1,429 +1,429 @@
|
||||
module mcts
|
||||
|
||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, makeNewState
|
||||
using Base.Threads
|
||||
using GeneralUtils
|
||||
|
||||
using ..type
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
"""
    selectBestNextNode(node::MCTSNode) -> MCTSNode

Select the highest-value child of `node`.

If any child carries a nonzero `statevalue` (i.e. rewards have been backpropagated),
children are ranked by mean state value `statevalue / visits`; otherwise they are
ranked by `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree; must have at least one child.

# Returns
- `childNode::MCTSNode`: the highest-value child node.

# Throws
- `ArgumentError`: if `node` has no children.
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
    # Guard: indexing node.children with a `nothing` key (or summing an empty
    # collection) would otherwise raise a confusing error downstream.
    isempty(node.children) && throw(ArgumentError("node has no children to select from"))

    # -Inf (not -1) so that children with negative potentials can still win;
    # with the old -1 sentinel, all-negative potentials left nodekey == nothing
    # and the final lookup raised a KeyError.
    bestvalue = -Inf
    nodekey = nothing

    # If every child has statevalue == 0 (nothing backpropagated yet), fall back
    # to progressvalue + reward to rank the children.
    statevaluesum = sum(child.statevalue for child in values(node.children))

    for child in values(node.children)
        potential = if statevaluesum != 0
            # Guard visits == 0: an unvisited child would produce NaN (0/0) or
            # ±Inf and corrupt the comparison.
            child.visits == 0 ? -Inf : child.statevalue / child.visits
        else
            child.progressvalue + child.reward
        end

        if potential > bestvalue
            bestvalue = potential
            nodekey = child.nodekey
        end
    end

    # All candidates scored -Inf (e.g. every child unvisited): return an
    # arbitrary child rather than indexing with `nothing`.
    nodekey === nothing && return first(values(node.children))

    return node.children[nodekey]
end
|
||||
|
||||
|
||||
"""
    selectBestTrajectoryNode(node::MCTSNode) -> MCTSNode

Descend from `node` by repeatedly taking the best child (via `selectBestNextNode`)
until a leaf is reached.

# Arguments
- `node::MCTSNode`: node of a search tree to start the descent from.

# Returns
- `childNode::MCTSNode`: the leaf node at the end of the best trajectory.
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
    current = node
    # Follow the chain of best children down to a leaf.
    while !isleaf(current)
        current = selectBestNextNode(current)
    end
    return current
end
|
||||
|
||||
|
||||
"""
    backpropagate(node::MCTSNode, simTrajectoryReward; discountRewardCoeff=0.9)

Backpropagate a simulated trajectory reward from `node` up to (but not including)
the root, incrementing visit counts and accumulating state values along the way.

# Arguments
- `node::MCTSNode`: leaf node of a search tree where the simulation ended.
- `simTrajectoryReward::T`: total reward from the trajectory simulation.
- `discountRewardCoeff::AbstractFloat=0.9`: discount applied at each step up the
  tree — the further a reward lies in the future, the less it is worth now.

# Returns
- `nothing`; node statistics are mutated in place.
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
                       discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
    # The root itself is intentionally not updated (loop exits when node is root).
    while !isroot(node)
        # Update the statistics of the current node based on the playout result.
        node.visits += 1
        # BUGFIX: the previous code did
        #   statevalue += ((statevalue * (visits-1)) + reward) / visits
        # i.e. it ADDED the incremental-mean expression onto the old statevalue,
        # double-counting the running value. Selection elsewhere ranks children by
        # statevalue / visits, so statevalue must be a running TOTAL of rewards —
        # simply accumulate.
        node.statevalue += simTrajectoryReward
        # Discount before moving up: future reward is uncertain.
        simTrajectoryReward *= discountRewardCoeff
        node = node.parent
    end
    return nothing
end
|
||||
|
||||
|
||||
"""
    isleaf(node::MCTSNode) -> Bool

Return `true` when `node` has no children, i.e. it is a leaf of the search tree,
and `false` otherwise.

# Arguments
- `node::MCTSNode`: a search tree node.

# Returns
- `result::Bool`: `true` if it is a leaf node, `false` otherwise.
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
|
||||
"""
    isroot(node::MCTSNode) -> Bool

Return `true` when `node` is the root of the search tree, identified by its
`nodekey` being the literal string `"root"`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `isrootnode::Bool`: `true` if the given node is the root node, `false` otherwise.
"""
# Idiom fix: the comparison already yields a Bool; `? true : false` was redundant.
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
|
||||
|
||||
|
||||
|
||||
"""
    selectChildNode(node::MCTSNode) -> MCTSNode

Select the child of `node` with the highest `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree; must have at least one child.

# Returns
- `childNode::MCTSNode`: the highest-value child node.

# Throws
- `ArgumentError`: if `node` has no children.
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    # Guard: with no children, nodekey would stay `nothing` and the final lookup
    # would raise a confusing KeyError.
    isempty(node.children) && throw(ArgumentError("node has no children to select from"))

    # -Inf (not -1) so children with all-negative potentials can still be chosen;
    # the old -1 sentinel left nodekey == nothing in that case → KeyError.
    bestvalue = -Inf
    nodekey = nothing

    # Scan the children dictionary for the highest progress value.
    for child in values(node.children)
        potential = child.progressvalue + child.reward
        if potential > bestvalue
            bestvalue = potential
            nodekey = child.nodekey
        end
    end

    return node.children[nodekey]
end
|
||||
|
||||
|
||||
"""
    expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
           totalsample::Integer=3)

Expand `node` horizontally by drawing `totalsample` candidate states from the
user-supplied `transition` function and attaching each previously unseen state
as a child node.

# Arguments
- `node::MCTSNode`: MCTS node to expand.
- `transition::Function`: handles the state transition; called as
  `transition(node.state, transitionargs)` and expected to return a dict-like
  result with keys `:newNodeKey`, `:newstate` and `:progressvalue`.
- `transitionargs::NamedTuple`: arguments forwarded to `transition()`.
- `totalsample::Integer=3`: total number to sample from the current node
  (i.e. expand new nodes horizontally).

# Returns
- `nothing`; `node.children` is mutated in place.
"""
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                totalsample::Integer=3)
    # Cleanup: the previous `while true` + manual counter is an ordinary bounded
    # loop; it also evaluated a bare triple-quoted string on every iteration
    # (a no-op allocation, not a comment), and a large commented-out threaded
    # variant preceded the function. All replaced/removed.
    for _ in 1:totalsample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]

        # TODO(review): a new state whose semantic vector is close enough to an
        # existing child's could be treated as the same state ("déjà vu") and
        # used to recall lessons for decisionMaker and evaluator.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                         newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end
    return nothing
end
|
||||
|
||||
|
||||
"""
    simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
             maxdepth=3, totalsample=3)

Simulate interactions between agent and environment by rolling out from `node`:
at each level the node's reward is accumulated; on a terminal node the rollout
stops, otherwise the node is expanded and its best child is followed.

# Arguments
- `node::MCTSNode`: node that will be the simulation starting point.
- `transition::Function`: user function that handles state transitions.
- `transitionargs::NamedTuple`: arguments the user needs within `transition()`.
- `maxdepth::Integer=3`: maximum depth level MCTS goes vertically.
- `totalsample::Integer=3`: total number to sample from each node (horizontal expansion).

# Returns
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`:
  accumulated trajectory reward and the terminal state reached (or `nothing`).
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                  maxdepth::Integer=3, totalsample::Integer=3
                  )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}

    trajectoryReward = 0.0
    terminalstate = nothing
    current = node

    depth = 0
    while depth < maxdepth
        depth += 1
        trajectoryReward += current.reward
        if current.isterminal
            # Terminal state ends the rollout early.
            terminalstate = current.state
            break
        end
        # Not terminal: grow the frontier and step down to the best child.
        expand(current, transition, transitionargs; totalsample=totalsample)
        current = selectChildNode(current)
    end

    return (simTrajectoryReward=trajectoryReward, terminalstate=terminalstate)
end
|
||||
|
||||
|
||||
"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)

Build a successor state: deep-copy `currentstate`, append the newest thought/action
pair found in `thoughtDict` plus the environment `response` (as the observation) to
the copy's `:thoughtHistory`, stamp `reward`, `select` and `isterminal` onto it, and
pair it with a fresh UUID-based node key.

# Arguments
- `currentstate::T1<:AbstractDict`: current MCTS state; must contain `:thoughtHistory`.
- `thoughtDict::T4<:AbstractDict`: holds the newest thought/action under either the
  plain `:thought`/`:action` keys or numbered `thought_N`/`action_N` keys.
- `response::T2<:AbstractString`: environment response, stored as the new observation.
- `select::Union{T3, Nothing}`: selection value recorded on the new state.
- `reward::T3<:Number`: reward recorded on the new state.
- `isterminal::Bool`: whether the new state is terminal.

# Returns
- `(newNodeKey, newstate)::Tuple{String, Dict{Symbol, <:Any}}`: a fresh node key and
  the new state dictionary.
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
                      reward::T3, isterminal::Bool
                      )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}

    # Find the highest-numbered "thought_N" key already in the history.
    # NOTE(review): assumes findHighestIndexKey returns (key, index) and key == :NA
    # when no such key exists — confirm against GeneralUtils.
    currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    # Next free index: 1 for an empty history, otherwise latest + 1.
    currentstate_nextIndice =
        currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
    currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
    latestActionKey = Symbol("action_$currentstate_nextIndice")

    # Locate the newest thought/action pair inside thoughtDict.
    _, thoughtDict_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(thoughtDict, "thought")

    # Fall back to the unindexed :thought/:action keys when thoughtDict has no
    # numbered keys (index -1 presumably means "not found" — confirm).
    thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
        if thoughtDict_latestThoughtIndice == -1
            (:thought, :action)
        else
            (
                Symbol("thought_$thoughtDict_latestThoughtIndice"),
                Symbol("action_$thoughtDict_latestThoughtIndice"),
            )
        end

    # Add thought, action, observation to thoughtHistory on a deep copy so the
    # original state (and its nested :thoughtHistory) is left untouched.
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][currentstate_latestThoughtKey] =
        thoughtDict[thoughtDict_latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
    newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
    newstate[:thoughtHistory][newObservationKey] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    newNodeKey = GeneralUtils.uuid4snakecase()

    return (newNodeKey, newstate)
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
module mcts
|
||||
|
||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, makeNewState
|
||||
using Base.Threads
|
||||
using GeneralUtils
|
||||
|
||||
using ..type
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
"""
    selectBestNextNode(node::MCTSNode) -> MCTSNode

Select the highest-value child of `node`.

If any child carries a nonzero `statevalue` (i.e. rewards have been backpropagated),
children are ranked by mean state value `statevalue / visits`; otherwise they are
ranked by `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree; must have at least one child.

# Returns
- `childNode::MCTSNode`: the highest-value child node.

# Throws
- `ArgumentError`: if `node` has no children.
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
    # Guard: indexing node.children with a `nothing` key (or summing an empty
    # collection) would otherwise raise a confusing error downstream.
    isempty(node.children) && throw(ArgumentError("node has no children to select from"))

    # -Inf (not -1) so that children with negative potentials can still win;
    # with the old -1 sentinel, all-negative potentials left nodekey == nothing
    # and the final lookup raised a KeyError.
    bestvalue = -Inf
    nodekey = nothing

    # If every child has statevalue == 0 (nothing backpropagated yet), fall back
    # to progressvalue + reward to rank the children.
    statevaluesum = sum(child.statevalue for child in values(node.children))

    for child in values(node.children)
        potential = if statevaluesum != 0
            # Guard visits == 0: an unvisited child would produce NaN (0/0) or
            # ±Inf and corrupt the comparison.
            child.visits == 0 ? -Inf : child.statevalue / child.visits
        else
            child.progressvalue + child.reward
        end

        if potential > bestvalue
            bestvalue = potential
            nodekey = child.nodekey
        end
    end

    # All candidates scored -Inf (e.g. every child unvisited): return an
    # arbitrary child rather than indexing with `nothing`.
    nodekey === nothing && return first(values(node.children))

    return node.children[nodekey]
end
|
||||
|
||||
|
||||
"""
    selectBestTrajectoryNode(node::MCTSNode) -> MCTSNode

Descend from `node` by repeatedly taking the best child (via `selectBestNextNode`)
until a leaf is reached.

# Arguments
- `node::MCTSNode`: node of a search tree to start the descent from.

# Returns
- `childNode::MCTSNode`: the leaf node at the end of the best trajectory.
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
    current = node
    # Follow the chain of best children down to a leaf.
    while !isleaf(current)
        current = selectBestNextNode(current)
    end
    return current
end
|
||||
|
||||
|
||||
"""
    backpropagate(node::MCTSNode, simTrajectoryReward; discountRewardCoeff=0.9)

Backpropagate a simulated trajectory reward from `node` up to (but not including)
the root, incrementing visit counts and accumulating state values along the way.

# Arguments
- `node::MCTSNode`: leaf node of a search tree where the simulation ended.
- `simTrajectoryReward::T`: total reward from the trajectory simulation.
- `discountRewardCoeff::AbstractFloat=0.9`: discount applied at each step up the
  tree — the further a reward lies in the future, the less it is worth now.

# Returns
- `nothing`; node statistics are mutated in place.
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
                       discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
    # The root itself is intentionally not updated (loop exits when node is root).
    while !isroot(node)
        # Update the statistics of the current node based on the playout result.
        node.visits += 1
        # BUGFIX: the previous code did
        #   statevalue += ((statevalue * (visits-1)) + reward) / visits
        # i.e. it ADDED the incremental-mean expression onto the old statevalue,
        # double-counting the running value. Selection elsewhere ranks children by
        # statevalue / visits, so statevalue must be a running TOTAL of rewards —
        # simply accumulate.
        node.statevalue += simTrajectoryReward
        # Discount before moving up: future reward is uncertain.
        simTrajectoryReward *= discountRewardCoeff
        node = node.parent
    end
    return nothing
end
|
||||
|
||||
|
||||
"""
    isleaf(node::MCTSNode) -> Bool

Return `true` when `node` has no children, i.e. it is a leaf of the search tree,
and `false` otherwise.

# Arguments
- `node::MCTSNode`: a search tree node.

# Returns
- `result::Bool`: `true` if it is a leaf node, `false` otherwise.
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
|
||||
"""
    isroot(node::MCTSNode) -> Bool

Return `true` when `node` is the root of the search tree, identified by its
`nodekey` being the literal string `"root"`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `isrootnode::Bool`: `true` if the given node is the root node, `false` otherwise.
"""
# Idiom fix: the comparison already yields a Bool; `? true : false` was redundant.
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
|
||||
|
||||
|
||||
|
||||
"""
    selectChildNode(node::MCTSNode) -> MCTSNode

Select the child of `node` with the highest `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree; must have at least one child.

# Returns
- `childNode::MCTSNode`: the highest-value child node.

# Throws
- `ArgumentError`: if `node` has no children.
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    # Guard: with no children, nodekey would stay `nothing` and the final lookup
    # would raise a confusing KeyError.
    isempty(node.children) && throw(ArgumentError("node has no children to select from"))

    # -Inf (not -1) so children with all-negative potentials can still be chosen;
    # the old -1 sentinel left nodekey == nothing in that case → KeyError.
    bestvalue = -Inf
    nodekey = nothing

    # Scan the children dictionary for the highest progress value.
    for child in values(node.children)
        potential = child.progressvalue + child.reward
        if potential > bestvalue
            bestvalue = potential
            nodekey = child.nodekey
        end
    end

    return node.children[nodekey]
end
|
||||
|
||||
|
||||
"""
    expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
           totalsample::Integer=3)

Expand `node` horizontally by drawing `totalsample` candidate states from the
user-supplied `transition` function and attaching each previously unseen state
as a child node.

# Arguments
- `node::MCTSNode`: MCTS node to expand.
- `transition::Function`: handles the state transition; called as
  `transition(node.state, transitionargs)` and expected to return a dict-like
  result with keys `:newNodeKey`, `:newstate` and `:progressvalue`.
- `transitionargs::NamedTuple`: arguments forwarded to `transition()`.
- `totalsample::Integer=3`: total number to sample from the current node
  (i.e. expand new nodes horizontally).

# Returns
- `nothing`; `node.children` is mutated in place.
"""
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                totalsample::Integer=3)
    # Cleanup: the previous `while true` + manual counter is an ordinary bounded
    # loop; it also evaluated a bare triple-quoted string on every iteration
    # (a no-op allocation, not a comment), and a large commented-out threaded
    # variant preceded the function. All replaced/removed.
    for _ in 1:totalsample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]

        # TODO(review): a new state whose semantic vector is close enough to an
        # existing child's could be treated as the same state ("déjà vu") and
        # used to recall lessons for decisionMaker and evaluator.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                         newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end
    return nothing
end
|
||||
|
||||
|
||||
"""
    simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
             maxdepth=3, totalsample=3)

Simulate interactions between agent and environment by rolling out from `node`:
at each level the node's reward is accumulated; on a terminal node the rollout
stops, otherwise the node is expanded and its best child is followed.

# Arguments
- `node::MCTSNode`: node that will be the simulation starting point.
- `transition::Function`: user function that handles state transitions.
- `transitionargs::NamedTuple`: arguments the user needs within `transition()`.
- `maxdepth::Integer=3`: maximum depth level MCTS goes vertically.
- `totalsample::Integer=3`: total number to sample from each node (horizontal expansion).

# Returns
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`:
  accumulated trajectory reward and the terminal state reached (or `nothing`).
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                  maxdepth::Integer=3, totalsample::Integer=3
                  )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}

    trajectoryReward = 0.0
    terminalstate = nothing
    current = node

    depth = 0
    while depth < maxdepth
        depth += 1
        trajectoryReward += current.reward
        if current.isterminal
            # Terminal state ends the rollout early.
            terminalstate = current.state
            break
        end
        # Not terminal: grow the frontier and step down to the best child.
        expand(current, transition, transitionargs; totalsample=totalsample)
        current = selectChildNode(current)
    end

    return (simTrajectoryReward=trajectoryReward, terminalstate=terminalstate)
end
|
||||
|
||||
|
||||
"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)

Build a successor state: deep-copy `currentstate`, append the newest thought/action
pair found in `thoughtDict` plus the environment `response` (as the observation) to
the copy's `:thoughtHistory`, stamp `reward`, `select` and `isterminal` onto it, and
pair it with a fresh UUID-based node key.

# Arguments
- `currentstate::T1<:AbstractDict`: current MCTS state; must contain `:thoughtHistory`.
- `thoughtDict::T4<:AbstractDict`: holds the newest thought/action under either the
  plain `:thought`/`:action` keys or numbered `thought_N`/`action_N` keys.
- `response::T2<:AbstractString`: environment response, stored as the new observation.
- `select::Union{T3, Nothing}`: selection value recorded on the new state.
- `reward::T3<:Number`: reward recorded on the new state.
- `isterminal::Bool`: whether the new state is terminal.

# Returns
- `(newNodeKey, newstate)::Tuple{String, Dict{Symbol, <:Any}}`: a fresh node key and
  the new state dictionary.
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
                      reward::T3, isterminal::Bool
                      )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}

    # Find the highest-numbered "thought_N" key already in the history.
    # NOTE(review): assumes findHighestIndexKey returns (key, index) and key == :NA
    # when no such key exists — confirm against GeneralUtils.
    currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    # Next free index: 1 for an empty history, otherwise latest + 1.
    currentstate_nextIndice =
        currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
    currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
    latestActionKey = Symbol("action_$currentstate_nextIndice")

    # Locate the newest thought/action pair inside thoughtDict.
    _, thoughtDict_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(thoughtDict, "thought")

    # Fall back to the unindexed :thought/:action keys when thoughtDict has no
    # numbered keys (index -1 presumably means "not found" — confirm).
    thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
        if thoughtDict_latestThoughtIndice == -1
            (:thought, :action)
        else
            (
                Symbol("thought_$thoughtDict_latestThoughtIndice"),
                Symbol("action_$thoughtDict_latestThoughtIndice"),
            )
        end

    # Add thought, action, observation to thoughtHistory on a deep copy so the
    # original state (and its nested :thoughtHistory) is left untouched.
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][currentstate_latestThoughtKey] =
        thoughtDict[thoughtDict_latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
    newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
    newstate[:thoughtHistory][newObservationKey] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    newNodeKey = GeneralUtils.uuid4snakecase()

    return (newNodeKey, newstate)
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module mcts
|
||||
Reference in New Issue
Block a user