module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using Base.Threads
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
    selectBestNextNode(node::MCTSNode) -> MCTSNode

Pick the most promising child of `node`. When at least one child carries a
non-zero `statevalue`, children are ranked by mean state value
(`statevalue / visits`); otherwise they are ranked by
`progressvalue + reward`.

# Arguments
- `node::MCTSNode`
    node of a search tree (assumed to have at least one child)

# Return
- `childNode::MCTSNode`
    the highest value child node
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
    # Decide the ranking criterion once, before the scan: a zero sum means no
    # child has been credited with state value yet, so fall back to
    # progressvalue + reward.
    useStateValue = sum([v.statevalue for (k, v) in node.children]) != 0

    bestScore = -1
    bestKey = nothing
    for (key, child) in node.children
        score = useStateValue ? child.statevalue / child.visits :
                                child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end

    # NOTE(review): if every score is <= -1 the lookup below is given
    # `nothing` and throws — callers are assumed to pass visited, expanded
    # nodes. Likewise `visits` is assumed non-zero when statevalue is used —
    # TODO confirm against callers.
    return node.children[bestKey]
end
"""
    selectBestTrajectoryNode(node::MCTSNode) -> MCTSNode

Walk down the tree from `node`, repeatedly following the best child as chosen
by `selectBestNextNode`, until a leaf is reached.

# Arguments
- `node::MCTSNode`
    node of a search tree

# Return
- `childNode::MCTSNode`
    the leaf node at the end of the greedy trajectory
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
    current = node
    while !isleaf(current)
        current = selectBestNextNode(current)
    end
    return current
end
""" Backpropagate reward along the simulation chain

Walk from a leaf toward the root, incrementing visit counts and accumulating
the (progressively discounted) simulation reward into each ancestor's
`statevalue`. Consumers such as `selectBestNextNode` divide `statevalue` by
`visits` to recover a mean value, so `statevalue` is maintained as a running
SUM of rewards here.

# Arguments
- `node::MCTSNode`
    leaf node of a search tree
- `simTrajectoryReward::T`
    total reward from trajectory simulation
- `discountRewardCoeff::AbstractFloat`
    A discount reward coefficient to reduce future reward. The further in the
    future the lower reward it is now.

# Return
- `nothing`

Note: the root node itself is never updated — the loop stops as soon as
`isroot(node)` is true.
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
                       discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
    while !isroot(node)
        # Update the statistics of the current node based on the result of the playout
        node.visits += 1
        # Accumulate the reward sum. The previous code added a running-mean
        # expression (((statevalue*(visits-1)) + reward) / visits) on top of
        # the existing statevalue with `+=`, which inflated statevalue on
        # every visit and made `statevalue / visits` meaningless.
        node.statevalue += simTrajectoryReward
        simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
        node = node.parent
    end
end
"""
    isleaf(node::MCTSNode) -> Bool

Determine whether a node is a leaf node of a search tree, i.e. whether it has
no children.

# Arguments
- `node::MCTSNode`
    a search tree node

# Return
- `result::Bool`
    true if it is a leaf node, false otherwise.
"""
function isleaf(node::MCTSNode)::Bool
    return isempty(node.children)
end
""" Determine whether a given node is a root node

The root is identified by the sentinel node key `"root"`.

# Arguments
- `node::MCTSNode`
    node of a search tree

# Return
- `isrootnode::Bool`
    true if the given node is root node, false otherwise
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
""" Select child node based on the highest progress value

Scan the children of `node` and return the one with the largest
`progressvalue + reward`.

# Arguments
- `node::MCTSNode`
    node of a search tree (assumed to have at least one child)

# Return
- `childNode::MCTSNode`
    the highest value child node
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    bestScore = -1
    bestKey = nothing

    # Loop through the node's children dictionary to find the highest
    # progress + reward score; ties keep the first child encountered.
    for (key, child) in node.children
        score = child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end

    # NOTE(review): if every score is <= -1, bestKey stays `nothing` and the
    # lookup below throws a KeyError.
    return node.children[bestKey]
end
""" Expand selected node.

Sample up to `horizontalSample` candidate child states via `transition` and
attach each new one to `node.children`, skipping candidates whose node key
already exists.

# Arguments
- `node::MCTSNode`
    MCTS node
- `transition::Function`
    A function that handles state transition; called as
    `transition(node.state, transitionargs)` and expected to return a Dict
    with keys `:newNodeKey`, `:newstate`, and `:progressvalue`.
- `transitionargs::NamedTuple`
    Arguments for transition()
- `horizontalSample::Integer`
    Total number to sample from the current node (i.e. expand new node
    horizontally)

# Return
- `listOfNewNodeId::Vector{String}`
    node keys of the children created by this call (previous docs wrongly
    said None)
"""
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                horizontalSample::Integer=3)

    listOfNewNodeId = String[]
    for _ in 1:horizontalSample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]

        # TODO: treat semantically-near-duplicate states, not just exact key
        # collisions. A new state may have a semantic vector close enough to
        # an existing child state to count as the same state (deja vu); that
        # could be used to recall lessons from similar situations to improve
        # the decision maker and evaluator.
        if newNodeKey ∉ keys(node.children)
            push!(listOfNewNodeId, newNodeKey)
            node.children[newNodeKey] = MCTSNode(
                newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                newstate[:isterminal], node, Dict{String, MCTSNode}(),
                Dict{Symbol, Any}())
        end
    end
    return listOfNewNodeId
end
""" Simulate interactions between agent and environment

Starting from `node`, repeatedly accumulate the node's reward, expand its
children, and step into the most promising child (per `selectChildNode`)
until a terminal node is reached or `maxdepth` levels have been traversed.
The resulting tuple is `put!` onto `outputchannel`.

# Arguments
- `outputchannel::Channel`
    channel the result tuple is written to.
- `node::MCTSNode`
    node that will be a simulation starting point.
- `transition::Function`
    A user function that handles how state transition.
- `transitionargs::NamedTuple`
    Arguments for everything the user will use within transition().
- `maxdepth::Integer`
    maximum depth level MCTS goes vertically.
- `horizontalSample::Integer`
    Total number to sample from the current node (i.e. expand new node
    horizontally)

# Return
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
    accumulated (undiscounted) trajectory reward, and the terminal state
    reached or `nothing` if none was found within `maxdepth`.
"""
function simulate(outputchannel::Channel, node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                  maxdepth::Integer=3, horizontalSample::Integer=3
                  )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}

    totalReward = 0.0
    finalState = nothing

    for _ in 1:maxdepth
        totalReward += node.reward
        if node.isterminal
            finalState = node.state
            break
        end
        # Grow the frontier, then descend into the most promising child.
        expand(node, transition, transitionargs; horizontalSample=horizontalSample)
        node = selectChildNode(node)
    end

    result = (simTrajectoryReward=totalReward, terminalstate=finalState)
    put!(outputchannel, result)
    return result
end
"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)

Build a new state from `currentstate` by appending the latest thought/action
pair from `thoughtDict` and the environment `response` as the next
observation into `:thoughtHistory`, then stamping the bookkeeping fields
(`:reward`, `:select`, `:isterminal`). The parent state is deep-copied, not
mutated, and a fresh node key is generated for the new state.

# Arguments
- `currentstate::AbstractDict`
    current state; must contain a `:thoughtHistory` dict whose entries are
    keyed `:thought_<i>` / `:action_<i>` / `:observation_<i>`.
- `thoughtDict::AbstractDict`
    source of the newest thought/action, keyed either `:thought`/`:action`
    or `:thought_<i>`/`:action_<i>` — presumably depending on what
    `GeneralUtils.findHighestIndexKey` reports; TODO confirm its contract.
- `response::AbstractString`
    environment response recorded as the new observation.
- `select::Union{Number, Nothing}`
    selection value stored into the new state under `:select`.
- `reward::Number`
    reward stored into the new state under `:reward`.
- `isterminal::Bool`
    whether the new state is terminal.

# Return
- `::Tuple{String, Dict{Symbol, <:Any}}`
    `(newNodeKey, newstate)` where `newNodeKey` comes from
    `GeneralUtils.uuid4snakecase()`.
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
                      reward::T3, isterminal::Bool
                      )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}

    # Find the highest-indexed "thought_<i>" key already in the history so the
    # new entries continue the numbering.
    currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    # :NA appears to be the helper's "no key found" sentinel — TODO confirm.
    currentstate_nextIndice =
        currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
    currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
    latestActionKey = Symbol("action_$currentstate_nextIndice")

    _, thoughtDict_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(thoughtDict, "thought")

    # NOTE(review): here the "not found" sentinel is -1, while above it is
    # :NA — both come from findHighestIndexKey; verify the two checks agree.
    thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
        if thoughtDict_latestThoughtIndice == -1
            (:thought, :action)
        else
            (
                Symbol("thought_$thoughtDict_latestThoughtIndice"),
                Symbol("action_$thoughtDict_latestThoughtIndice"),
            )
        end

    # add Thought, action, observation to thoughtHistory
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][currentstate_latestThoughtKey] =
        thoughtDict[thoughtDict_latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
    newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
    newstate[:thoughtHistory][newObservationKey] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    newNodeKey = GeneralUtils.uuid4snakecase()

    return (newNodeKey, newstate)
end
end # module mcts |