This commit is contained in:
narawat lamaiin
2024-12-09 20:28:02 +07:00
parent 3338085567
commit cae94e5690
15 changed files with 2942 additions and 2942 deletions

View File

@@ -1,429 +1,429 @@
module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using Base.Threads
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
    selectBestNextNode(node::MCTSNode)::MCTSNode

Return the child of `node` with the highest selection score.

When any child carries a nonzero `statevalue`, children are ranked by their
mean state value (`statevalue / visits`); otherwise they are ranked by
`progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the highest-scoring child node.
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
    # If every child's statevalue is 0, fall back to progressvalue + reward.
    anyStateValue = sum([child.statevalue for (_, child) in node.children]) != 0
    bestScore = -1
    bestKey = nothing
    for (_, child) in node.children
        score = anyStateValue ?
            child.statevalue / child.visits :
            child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end
    return node.children[bestKey]
end
"""
    selectBestTrajectoryNode(node::MCTSNode)::MCTSNode

Descend from `node` along the best child at each level (via
`selectBestNextNode`) until a leaf node is reached.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the leaf node at the end of the best trajectory.
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
    current = node
    while !isleaf(current)
        current = selectBestNextNode(current)
    end
    return current
end
"""
    backpropagate(node::MCTSNode, simTrajectoryReward; discountRewardCoeff=0.9)

Propagate a simulation reward from `node` up to (but excluding) the root,
updating each ancestor's visit count and running-average state value.

# Arguments
- `node::MCTSNode`: leaf node of a search tree where the simulation ended.
- `simTrajectoryReward::T`: total reward from trajectory simulation.

# Keywords
- `discountRewardCoeff::AbstractFloat=0.9`: discount applied per level so
  that reward further in the future contributes less at the current node.

# Returns
- `nothing`
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
    discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
    while !isroot(node)
        # Update the statistics of the current node based on the playout result.
        node.visits += 1
        # Incremental running mean: new_mean = (old_mean*(n-1) + reward) / n.
        # Must be `=`, not `+=`: the right-hand side already contains the old
        # statevalue, so `+=` would double-count it on every update.
        node.statevalue = ((node.statevalue * (node.visits - 1)) + simTrajectoryReward) / node.visits
        simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
        node = node.parent
    end
end
"""
    isleaf(node::MCTSNode)::Bool

Determine whether a node is a leaf node of a search tree, i.e. whether it
has no children.

# Arguments
- `node::MCTSNode`: a search tree node.

# Returns
- `Bool`: `true` if `node` is a leaf node, `false` otherwise.
"""
function isleaf(node::MCTSNode)::Bool
    return isempty(node.children)
end
"""
    isroot(node::MCTSNode)::Bool

Determine whether a given node is the root node of the search tree. A node
is the root exactly when its `nodekey` equals `"root"`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `Bool`: `true` if the given node is the root node, `false` otherwise.
"""
# The comparison already yields a Bool; the former `? true : false` was redundant.
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
"""
    selectChildNode(node::MCTSNode)::MCTSNode

Select the child of `node` with the highest `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the highest-scoring child node.
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    bestScore = -1
    bestKey = nothing
    # Scan the children dictionary for the highest progressvalue + reward.
    for (_, child) in node.children
        score = child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end
    return node.children[bestKey]
end
""" Expand selected node.
# Arguments
- `node::MCTSNode`
MCTS node
- `transition::Function`
A function that handles state transition.
- `transitionargs::NamedTuple`
Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- None
# Signature
"""
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3)
# # not use Any[] because I want to preserve result order
# results = Vector{Any}(undef, totalsample)
# @sync for i in 1:totalsample
# @spawn begin
# result = transition(deepcopy(node.state), deepcopy(transitionargs))
# results[i] = result
# end
# end
# for result in results
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}())
# end
# end
# end
"""
    expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
           totalsample::Integer=3)

Expand `node` horizontally: sample up to `totalsample` successor states via
`transition` and attach each previously-unseen state as a new child node.

# Arguments
- `node::MCTSNode`: MCTS node to expand.
- `transition::Function`: function that handles the state transition; its
  result is indexed with `:newNodeKey`, `:newstate`, and `:progressvalue`.
- `transitionargs::NamedTuple`: arguments for `transition()`.

# Keywords
- `totalsample::Integer=3`: total number of samples drawn from the node.

# Returns
- `nothing`
"""
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
    totalsample::Integer=3)
    # Idiomatic counted loop replaces the former `while true` + manual counter.
    for _ in 1:totalsample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]
        # TODO on `newNodeKey ∉ keys(node.children)`: a new state may have a
        # semantic vector close enough to an existing child state to be
        # treated as the same state (deja vu); that similarity could be used
        # to recall lessons and improve decisionMaker and evaluator.
        # (Was a free-standing triple-quoted string, i.e. a no-op statement.)
        # Fix: the membership test below had lost its `∉` operator, leaving
        # `if newNodeKey keys(node.children)` — a syntax error. Only add
        # genuinely new child nodes, matching the commented reference version.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                    newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end
end
"""
    simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
             maxdepth::Integer=3, totalsample::Integer=3)

Simulate interactions between agent and environment starting from `node`,
accumulating rewards level by level until a terminal node is reached or
`maxdepth` levels have been explored.

# Arguments
- `node::MCTSNode`: node that will be the simulation starting point.
- `transition::Function`: user function that handles the state transition.
- `transitionargs::NamedTuple`: arguments the user needs within `transition()`.

# Keywords
- `maxdepth::Integer=3`: maximum depth level MCTS goes vertically.
- `totalsample::Integer=3`: samples per node when expanding horizontally.

# Returns
- `NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`:
  accumulated reward and the terminal state (`nothing` if none was reached).
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
    maxdepth::Integer=3, totalsample::Integer=3
    )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
    totalReward = 0.0
    endState = nothing
    for _ in 1:maxdepth
        totalReward += node.reward
        # Early exit on a terminal node; otherwise expand and descend.
        if node.isterminal
            endState = node.state
            break
        end
        expand(node, transition, transitionargs; totalsample=totalsample)
        node = selectChildNode(node)
    end
    return (simTrajectoryReward=totalReward, terminalstate=endState)
end
"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)

Build the successor state: deep-copy `currentstate`, append the newest
thought/action pair from `thoughtDict` plus the environment `response` as a
new observation in `:thoughtHistory`, and record `reward`, `select`, and
`isterminal`. Returns a fresh node key together with the new state.

# Arguments
- `currentstate::AbstractDict`: current agent state; must contain `:thoughtHistory`.
- `thoughtDict::AbstractDict`: holds the newest thought/action entries, keyed
  either `:thought`/`:action` or `:thought_<i>`/`:action_<i>`.
- `response::AbstractString`: environment response stored as the new observation.
- `select::Union{T3, Nothing}`: selection value recorded on the new state.
- `reward::T3<:Number`: reward recorded on the new state.
- `isterminal::Bool`: whether the new state is terminal.

# Returns
- `Tuple{String, Dict{Symbol, <:Any}}`: `(newNodeKey, newstate)` where
  `newNodeKey` comes from `GeneralUtils.uuid4snakecase()` (presumably a
  UUID-derived key — confirm against GeneralUtils).

# TODO
- [x] implement the function
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
    reward::T3, isterminal::Bool
    )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
    # Find the highest-numbered "thought_<i>" key already present; :NA means
    # the history has none yet, so the next index starts at 1.
    currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    currentstate_nextIndice =
        currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
    currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
    latestActionKey = Symbol("action_$currentstate_nextIndice")
    # Locate the newest thought/action pair inside thoughtDict; an index of -1
    # means the entries are unindexed (:thought / :action).
    _, thoughtDict_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
    thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
        if thoughtDict_latestThoughtIndice == -1
            (:thought, :action)
        else
            (
                Symbol("thought_$thoughtDict_latestThoughtIndice"),
                Symbol("action_$thoughtDict_latestThoughtIndice"),
            )
        end
    # add Thought, action, observation to thoughtHistory — on a deep copy so
    # the caller's current state is left untouched.
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][currentstate_latestThoughtKey] =
        thoughtDict[thoughtDict_latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
    newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
    newstate[:thoughtHistory][newObservationKey] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal
    newNodeKey = GeneralUtils.uuid4snakecase()
    return (newNodeKey, newstate)
end
module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using Base.Threads
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
    selectBestNextNode(node::MCTSNode)::MCTSNode

Return the child of `node` with the highest selection score.

When any child carries a nonzero `statevalue`, children are ranked by their
mean state value (`statevalue / visits`); otherwise they are ranked by
`progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the highest-scoring child node.
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
    # If every child's statevalue is 0, fall back to progressvalue + reward.
    anyStateValue = sum([child.statevalue for (_, child) in node.children]) != 0
    bestScore = -1
    bestKey = nothing
    for (_, child) in node.children
        score = anyStateValue ?
            child.statevalue / child.visits :
            child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end
    return node.children[bestKey]
end
"""
    selectBestTrajectoryNode(node::MCTSNode)::MCTSNode

Descend from `node` along the best child at each level (via
`selectBestNextNode`) until a leaf node is reached.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the leaf node at the end of the best trajectory.
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
    current = node
    while !isleaf(current)
        current = selectBestNextNode(current)
    end
    return current
end
"""
    backpropagate(node::MCTSNode, simTrajectoryReward; discountRewardCoeff=0.9)

Propagate a simulation reward from `node` up to (but excluding) the root,
updating each ancestor's visit count and running-average state value.

# Arguments
- `node::MCTSNode`: leaf node of a search tree where the simulation ended.
- `simTrajectoryReward::T`: total reward from trajectory simulation.

# Keywords
- `discountRewardCoeff::AbstractFloat=0.9`: discount applied per level so
  that reward further in the future contributes less at the current node.

# Returns
- `nothing`
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
    discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
    while !isroot(node)
        # Update the statistics of the current node based on the playout result.
        node.visits += 1
        # Incremental running mean: new_mean = (old_mean*(n-1) + reward) / n.
        # Must be `=`, not `+=`: the right-hand side already contains the old
        # statevalue, so `+=` would double-count it on every update.
        node.statevalue = ((node.statevalue * (node.visits - 1)) + simTrajectoryReward) / node.visits
        simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
        node = node.parent
    end
end
"""
    isleaf(node::MCTSNode)::Bool

Determine whether a node is a leaf node of a search tree, i.e. whether it
has no children.

# Arguments
- `node::MCTSNode`: a search tree node.

# Returns
- `Bool`: `true` if `node` is a leaf node, `false` otherwise.
"""
function isleaf(node::MCTSNode)::Bool
    return isempty(node.children)
end
"""
    isroot(node::MCTSNode)::Bool

Determine whether a given node is the root node of the search tree. A node
is the root exactly when its `nodekey` equals `"root"`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `Bool`: `true` if the given node is the root node, `false` otherwise.
"""
# The comparison already yields a Bool; the former `? true : false` was redundant.
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
"""
    selectChildNode(node::MCTSNode)::MCTSNode

Select the child of `node` with the highest `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the highest-scoring child node.
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    bestScore = -1
    bestKey = nothing
    # Scan the children dictionary for the highest progressvalue + reward.
    for (_, child) in node.children
        score = child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end
    return node.children[bestKey]
end
""" Expand selected node.
# Arguments
- `node::MCTSNode`
MCTS node
- `transition::Function`
A function that handles state transition.
- `transitionargs::NamedTuple`
Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- None
# Signature
"""
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3)
# # not use Any[] because I want to preserve result order
# results = Vector{Any}(undef, totalsample)
# @sync for i in 1:totalsample
# @spawn begin
# result = transition(deepcopy(node.state), deepcopy(transitionargs))
# results[i] = result
# end
# end
# for result in results
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}())
# end
# end
# end
"""
    expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
           totalsample::Integer=3)

Expand `node` horizontally: sample up to `totalsample` successor states via
`transition` and attach each previously-unseen state as a new child node.

# Arguments
- `node::MCTSNode`: MCTS node to expand.
- `transition::Function`: function that handles the state transition; its
  result is indexed with `:newNodeKey`, `:newstate`, and `:progressvalue`.
- `transitionargs::NamedTuple`: arguments for `transition()`.

# Keywords
- `totalsample::Integer=3`: total number of samples drawn from the node.

# Returns
- `nothing`
"""
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
    totalsample::Integer=3)
    # Idiomatic counted loop replaces the former `while true` + manual counter.
    for _ in 1:totalsample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]
        # TODO on `newNodeKey ∉ keys(node.children)`: a new state may have a
        # semantic vector close enough to an existing child state to be
        # treated as the same state (deja vu); that similarity could be used
        # to recall lessons and improve decisionMaker and evaluator.
        # (Was a free-standing triple-quoted string, i.e. a no-op statement.)
        # Fix: the membership test below had lost its `∉` operator, leaving
        # `if newNodeKey keys(node.children)` — a syntax error. Only add
        # genuinely new child nodes, matching the commented reference version.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                    newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end
end
"""
    simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
             maxdepth::Integer=3, totalsample::Integer=3)

Simulate interactions between agent and environment starting from `node`,
accumulating rewards level by level until a terminal node is reached or
`maxdepth` levels have been explored.

# Arguments
- `node::MCTSNode`: node that will be the simulation starting point.
- `transition::Function`: user function that handles the state transition.
- `transitionargs::NamedTuple`: arguments the user needs within `transition()`.

# Keywords
- `maxdepth::Integer=3`: maximum depth level MCTS goes vertically.
- `totalsample::Integer=3`: samples per node when expanding horizontally.

# Returns
- `NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`:
  accumulated reward and the terminal state (`nothing` if none was reached).
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
    maxdepth::Integer=3, totalsample::Integer=3
    )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
    totalReward = 0.0
    endState = nothing
    for _ in 1:maxdepth
        totalReward += node.reward
        # Early exit on a terminal node; otherwise expand and descend.
        if node.isterminal
            endState = node.state
            break
        end
        expand(node, transition, transitionargs; totalsample=totalsample)
        node = selectChildNode(node)
    end
    return (simTrajectoryReward=totalReward, terminalstate=endState)
end
"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)

Build the successor state: deep-copy `currentstate`, append the newest
thought/action pair from `thoughtDict` plus the environment `response` as a
new observation in `:thoughtHistory`, and record `reward`, `select`, and
`isterminal`. Returns a fresh node key together with the new state.

# Arguments
- `currentstate::AbstractDict`: current agent state; must contain `:thoughtHistory`.
- `thoughtDict::AbstractDict`: holds the newest thought/action entries, keyed
  either `:thought`/`:action` or `:thought_<i>`/`:action_<i>`.
- `response::AbstractString`: environment response stored as the new observation.
- `select::Union{T3, Nothing}`: selection value recorded on the new state.
- `reward::T3<:Number`: reward recorded on the new state.
- `isterminal::Bool`: whether the new state is terminal.

# Returns
- `Tuple{String, Dict{Symbol, <:Any}}`: `(newNodeKey, newstate)` where
  `newNodeKey` comes from `GeneralUtils.uuid4snakecase()` (presumably a
  UUID-derived key — confirm against GeneralUtils).

# TODO
- [x] implement the function
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
    reward::T3, isterminal::Bool
    )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
    # Find the highest-numbered "thought_<i>" key already present; :NA means
    # the history has none yet, so the next index starts at 1.
    currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    currentstate_nextIndice =
        currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
    currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
    latestActionKey = Symbol("action_$currentstate_nextIndice")
    # Locate the newest thought/action pair inside thoughtDict; an index of -1
    # means the entries are unindexed (:thought / :action).
    _, thoughtDict_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
    thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
        if thoughtDict_latestThoughtIndice == -1
            (:thought, :action)
        else
            (
                Symbol("thought_$thoughtDict_latestThoughtIndice"),
                Symbol("action_$thoughtDict_latestThoughtIndice"),
            )
        end
    # add Thought, action, observation to thoughtHistory — on a deep copy so
    # the caller's current state is left untouched.
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][currentstate_latestThoughtKey] =
        thoughtDict[thoughtDict_latestThoughtKey]
    newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
    newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
    newstate[:thoughtHistory][newObservationKey] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal
    newNodeKey = GeneralUtils.uuid4snakecase()
    return (newNodeKey, newstate)
end
end # module mcts