This commit is contained in:
narawat lamaiin
2024-07-10 11:38:59 +07:00
parent 9e39d54c4b
commit 05830f3d9a
3 changed files with 167 additions and 85 deletions

View File

@@ -1,6 +1,6 @@
module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using GeneralUtils
@@ -20,18 +20,9 @@ using ..type
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1
nodekey = nothing
@@ -72,20 +63,11 @@ end
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
while !isleaf(node)
node = selectBestNextState(node)
node = selectBestNextNode(node)
end
return node
@@ -99,14 +81,12 @@ end
leaf node of a search tree
- `simTrajectoryReward::T`
total reward from trajectory simulation
- `discountRewardCoeff::AbstractFloat`
A discount reward coefficient to reduce future reward. The further in the future a
reward is, the lower its present value.
# Return
- `No return`
# Example
```jldoctest
julia>
```
- `None`
# Signature
"""
@@ -166,11 +146,6 @@ isleaf(node::MCTSNode)::Bool = isempty(node.children)
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Example
```jldoctest
julia>
```
# Signature
"""
# Return true if `node` is the root of the search tree, identified by the
# sentinel nodekey "root". The `== "root"` comparison already yields a Bool,
# so the redundant `? true : false` ternary is dropped.
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
@@ -187,11 +162,6 @@ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
@@ -201,9 +171,6 @@ function selectChildNode(node::MCTSNode)::MCTSNode
# loop through the node's children dictionary to find the highest progress value
for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward
if childNode.reward > 0 #XXX for testing. remove when done.
println("")
end
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childNode.nodekey
@@ -214,34 +181,26 @@ function selectChildNode(node::MCTSNode)::MCTSNode
end
""" Expand selected node
""" Expand selected node.
# Arguments
- `a::T1`
One of YiemAgent's agent
- `node::MCTSNode`
MCTS node
- `state::T2`
a state of a game. Can be a Dict or something else.
- `decisionMaker::Function`
a function that output Thought and Action
- `evaluator::Function`
a function that output trajectory progress score
- `transition::Function`
A function that handles state transition.
- `transitionargs::NamedTuple`
Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- None
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] try loop should limit to 3 times. if not succeed, skip
[] newNodeKey ∉ keys(node.children). A new state may have a semantic vector close enough to one of the existing child states, in which case they can be assumed to be semantically the same state.
[x] store feedback -> state -> agent.
# Signature
"""
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
@@ -255,6 +214,14 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate]
progressvalue::Integer = result[:progressvalue]
"""
[] newNodeKey ∉ keys(node.children).
A new state may have a semantic vector close enough to
one of the existing child states, in which case they can be assumed to be the same
state semantically, i.e. déjà vu. This could be used to recall lessons from
similar situations to improve decisionMaker and evaluator.
"""
if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
@@ -270,30 +237,25 @@ end
""" Simulate interactions between agent and environment
# Arguments
- `a::T`
one of YiemAgent's agents
- `node::MCTSNode`
node that will be a simulation starting point.
- `decisionMaker::Function`
function that receive state return Thought and Action
- `transition::Function`
A user function that handles how state transition.
- `transitionargs::NamedTuple`
Arguments for everything the user will use within transition().
- `maxdepth::Integer`
maximum depth level MCTS goes vertically.
- totalsample::Integer
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- `simTrajectoryReward::Number`
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- `(simTrajectoryReward, terminalstate)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}`
# Signature
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
simTrajectoryReward = 0.0
terminalstate = nothing