update
This commit is contained in:
@@ -12,7 +12,7 @@ using ..type, ..mcts, ..util
|
||||
""" Search for the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `initial state`
|
||||
- `initialstate::T`
|
||||
initial state
|
||||
- `transition::Function`
|
||||
a function that defines how the state transitions
|
||||
@@ -32,21 +32,16 @@ using ..type, ..mcts, ..util
|
||||
aggressively explore new state.
|
||||
|
||||
# Return
|
||||
- `(bestNextState, BestFinalState)::Tuple`
|
||||
- `(bestNextState, bestFinalState)::@NamedTuple{bestNextState::T, bestFinalState::T}`
|
||||
the best next state and the best final state
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update example
|
||||
Refers to SQLLLM package
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function runMCTS(
|
||||
initialstate,
|
||||
initialstate::T,
|
||||
transition::Function,
|
||||
transitionargs::NamedTuple,
|
||||
;
|
||||
@@ -54,7 +49,7 @@ function runMCTS(
|
||||
maxdepth::Integer=3,
|
||||
maxiterations::Integer=10,
|
||||
explorationweight::Number=1.0,
|
||||
)::NamedTuple
|
||||
)::@NamedTuple{bestNextState::T, bestFinalState::T} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
|
||||
@@ -88,8 +83,8 @@ function runMCTS(
|
||||
end
|
||||
end
|
||||
|
||||
bestNextState = selectBestNextState(root)
|
||||
besttrajectory = selectBestTrajectory(root)
|
||||
bestNextState = selectBestNextNode(root)
|
||||
besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
end
|
||||
|
||||
106
src/mcts.jl
106
src/mcts.jl
@@ -1,6 +1,6 @@
|
||||
module mcts
|
||||
|
||||
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
|
||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, makeNewState
|
||||
|
||||
using GeneralUtils
|
||||
@@ -20,18 +20,9 @@ using ..type
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docs
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectBestNextState(node::MCTSNode)::MCTSNode
|
||||
function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = -1
|
||||
nodekey = nothing
|
||||
|
||||
@@ -72,20 +63,11 @@ end
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docs
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectBestTrajectory(node::MCTSNode)::MCTSNode
|
||||
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
|
||||
while !isleaf(node)
|
||||
node = selectBestNextState(node)
|
||||
node = selectBestNextNode(node)
|
||||
end
|
||||
|
||||
return node
|
||||
@@ -99,14 +81,12 @@ end
|
||||
leaf node of a search tree
|
||||
- `simTrajectoryReward::T`
|
||||
total reward from trajectory simulation
|
||||
- `discountRewardCoeff::AbstractFloat`
|
||||
A discount reward coefficient to reduce future reward. The further in the future the lower
|
||||
reward it is now.
|
||||
|
||||
# Return
|
||||
- `No return`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
- `None`
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -166,11 +146,6 @@ isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
- `isrootnode::Bool`
|
||||
true if the given node is root node, false otherwise
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# Signature
|
||||
"""
|
||||
# A node is the root exactly when its key is the literal "root".
isroot(node::MCTSNode)::Bool = node.nodekey == "root"
|
||||
@@ -187,11 +162,6 @@ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
@@ -201,9 +171,6 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
# loop through the node's children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
potential = childNode.progressvalue + childNode.reward
|
||||
if childNode.reward > 0 #XXX for testing. remove when done.
|
||||
println("")
|
||||
end
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childNode.nodekey
|
||||
@@ -214,34 +181,26 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
end
|
||||
|
||||
|
||||
""" Expand selected node
|
||||
""" Expand selected node.
|
||||
|
||||
# Arguments
|
||||
- `a::T1`
|
||||
One of YiemAgent's agents
|
||||
- `node::MCTSNode`
|
||||
MCTS node
|
||||
- `state::T2`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
- `decisionMaker::Function`
|
||||
a function that outputs a Thought and an Action
|
||||
- `evaluator::Function`
|
||||
a function that outputs the trajectory progress score
|
||||
- `transition::Function`
|
||||
A function that handles state transition.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for transition()
|
||||
- `totalsample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Return
|
||||
- None
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
[] try loop should limit to 3 times. if not succeed, skip
|
||||
[] newNodeKey ∉ keys(node.children). A new state may have a semantic vector close enough to one of the existing child states, in which case they can be assumed to be the same state semantically.
|
||||
[x] store feedback -> state -> agent.
|
||||
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
@@ -255,6 +214,14 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
newNodeKey::AbstractString = result[:newNodeKey]
|
||||
newstate::AbstractDict = result[:newstate]
|
||||
progressvalue::Integer = result[:progressvalue]
|
||||
|
||||
"""
|
||||
[] newNodeKey ∉ keys(node.children).
|
||||
New state may have semantic vector close enought to
|
||||
one of existing child state. Which can be assume that they are the same state
|
||||
semantically-wise i.e. De javu. This could be used to recall lessons for this
|
||||
similar situation to improve decisionMaker and evaluator.
|
||||
"""
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] =
|
||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
@@ -270,30 +237,25 @@ end
|
||||
""" Simulate interactions between agent and environment
|
||||
|
||||
# Arguments
|
||||
- `a::T`
|
||||
one of YiemAgent's agents
|
||||
- `node::MCTSNode`
|
||||
node that will be a simulation starting point.
|
||||
- `decisionMaker::Function`
|
||||
a function that receives a state and returns a Thought and an Action
|
||||
|
||||
- `transition::Function`
|
||||
A user function that handles state transitions.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for everything the user will use within transition().
|
||||
- `maxdepth::Integer`
|
||||
maximum depth level MCTS goes vertically.
|
||||
- `totalsample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Return
|
||||
- `simTrajectoryReward::Number`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docs
|
||||
- `(simTrajectoryReward, terminalstate)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}`
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxdepth::Integer=3, totalsample::Integer=3
|
||||
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
terminalstate = nothing
|
||||
|
||||
|
||||
Reference in New Issue
Block a user