update
This commit is contained in:
@@ -1,38 +1,38 @@
|
||||
module LLMMCTS
|
||||
|
||||
# export agent
|
||||
|
||||
|
||||
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
||||
files and each file can only depend on the file included before it.
|
||||
"""
|
||||
|
||||
include("type.jl")
|
||||
using .type
|
||||
|
||||
include("util.jl")
|
||||
using .util
|
||||
|
||||
include("mcts.jl")
|
||||
using .mcts
|
||||
|
||||
include("interface.jl")
|
||||
using .interface
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
""" version 0.0.2
|
||||
Todo:
|
||||
- []
|
||||
|
||||
Change from version: 0.0.1
|
||||
-
|
||||
|
||||
All features
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
end # module LLMMCTS
|
||||
module LLMMCTS
|
||||
|
||||
# export agent
|
||||
|
||||
|
||||
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
||||
files and each file can only depend on the file included before it.
|
||||
"""
|
||||
|
||||
include("type.jl")
|
||||
using .type
|
||||
|
||||
include("util.jl")
|
||||
using .util
|
||||
|
||||
include("mcts.jl")
|
||||
using .mcts
|
||||
|
||||
include("interface.jl")
|
||||
using .interface
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
""" version 0.0.2
|
||||
Todo:
|
||||
- []
|
||||
|
||||
Change from version: 0.0.1
|
||||
-
|
||||
|
||||
All features
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
end # module LLMMCTS
|
||||
|
||||
454
src/interface.jl
454
src/interface.jl
@@ -1,228 +1,228 @@
|
||||
module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using ..type, ..mcts, ..util
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `initialstate::T`
|
||||
initial state
|
||||
- `transition::Function`
|
||||
a function that define how the state transitions
|
||||
- `transitionargs::NamedTuple`
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `totalsample::Integer`
|
||||
a number of child state MCTS sample at each node during expansion phase
|
||||
- `maxdepth::Integer`
|
||||
a number of levels MCTS goes during simulation phase
|
||||
- `maxiterations::Integer`
|
||||
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new state instead of exploit
|
||||
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
|
||||
aggressively explore new state.
|
||||
|
||||
# Return
|
||||
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
|
||||
the best next state and the best final state
|
||||
|
||||
# Example
|
||||
Refers to SQLLLM package
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function runMCTS(
|
||||
initialstate::T,
|
||||
transition::Function,
|
||||
transitionargs::NamedTuple,
|
||||
;
|
||||
totalsample::Integer=3,
|
||||
maxdepth::Integer=3,
|
||||
maxiterations::Integer=10,
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing
|
||||
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
for nth in 1:maxiterations
|
||||
node = root
|
||||
node.visits += 1
|
||||
|
||||
while !isleaf(node)
|
||||
node = UCTselect(node, explorationweight)
|
||||
end
|
||||
|
||||
if node.isterminal
|
||||
# MCTS arrive at the leaf node that is also a terminal state,
|
||||
# do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
backpropagate(node, node.reward)
|
||||
else
|
||||
expand(node, transition, transitionargs;
|
||||
totalsample=totalsample)
|
||||
leafNode = selectChildNode(node)
|
||||
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
maxdepth=maxdepth, totalsample=totalsample)
|
||||
# if terminalstate !== nothing #XXX not sure why I need this
|
||||
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# end
|
||||
|
||||
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# open("trajectory.json", "w") do io
|
||||
# JSON3.pretty(io, terminalstate)
|
||||
# end
|
||||
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
|
||||
# stop if the early stop condition is met
|
||||
if typeof(earlystop) <: Function && earlystop(node.state)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
bestNextState = selectBestNextNode(root)
|
||||
besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
end
|
||||
|
||||
|
||||
# function runMCTS(
|
||||
# initialstate::T,
|
||||
# transition::Function,
|
||||
# transitionargs::NamedTuple,
|
||||
# ;
|
||||
# totalsample::Integer=3,
|
||||
# maxdepth::Integer=3,
|
||||
# maxiterations::Integer=10,
|
||||
# explorationweight::Number=1.0,
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
# for nth in 1:maxiterations
|
||||
# node = root
|
||||
# node.visits += 1
|
||||
|
||||
# while !isleaf(node)
|
||||
# node = UCTselect(node, explorationweight)
|
||||
# end
|
||||
# if node.isterminal
|
||||
# # MCTS arrive at the leaf node that is also a terminal state,
|
||||
# # do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
# backpropagate(leafNode, node.reward)
|
||||
# else
|
||||
# expand(node, transition, transitionargs;
|
||||
# totalsample=totalsample)
|
||||
# leafNode = selectChildNode(node)
|
||||
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
# maxdepth=maxdepth, totalsample=totalsample)
|
||||
# # if terminalstate !== nothing #XXX not sure why I need this
|
||||
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# # end
|
||||
|
||||
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# # open("trajectory.json", "w") do io
|
||||
# # JSON3.pretty(io, terminalstate)
|
||||
# # end
|
||||
|
||||
# backpropagate(leafNode, simTrajectoryReward)
|
||||
# end
|
||||
# end
|
||||
|
||||
# bestNextState = selectBestNextNode(root)
|
||||
# besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using ..type, ..mcts, ..util
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
- `initialstate::T`
|
||||
initial state
|
||||
- `transition::Function`
|
||||
a function that define how the state transitions
|
||||
- `transitionargs::NamedTuple`
|
||||
arguments for transition function
|
||||
|
||||
# Keyword Arguments
|
||||
- `totalsample::Integer`
|
||||
a number of child state MCTS sample at each node during expansion phase
|
||||
- `maxdepth::Integer`
|
||||
a number of levels MCTS goes during simulation phase
|
||||
- `maxiterations::Integer`
|
||||
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
|
||||
- `explorationweight::Number`
|
||||
exploration weight controls how much MCTS should explore new state instead of exploit
|
||||
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
|
||||
aggressively explore new state.
|
||||
|
||||
# Return
|
||||
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
|
||||
the best next state and the best final state
|
||||
|
||||
# Example
|
||||
Refers to SQLLLM package
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function runMCTS(
|
||||
initialstate::T,
|
||||
transition::Function,
|
||||
transitionargs::NamedTuple,
|
||||
;
|
||||
totalsample::Integer=3,
|
||||
maxdepth::Integer=3,
|
||||
maxiterations::Integer=10,
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing
|
||||
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
for nth in 1:maxiterations
|
||||
node = root
|
||||
node.visits += 1
|
||||
|
||||
while !isleaf(node)
|
||||
node = UCTselect(node, explorationweight)
|
||||
end
|
||||
|
||||
if node.isterminal
|
||||
# MCTS arrive at the leaf node that is also a terminal state,
|
||||
# do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
backpropagate(node, node.reward)
|
||||
else
|
||||
expand(node, transition, transitionargs;
|
||||
totalsample=totalsample)
|
||||
leafNode = selectChildNode(node)
|
||||
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
maxdepth=maxdepth, totalsample=totalsample)
|
||||
# if terminalstate !== nothing #XXX not sure why I need this
|
||||
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# end
|
||||
|
||||
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# open("trajectory.json", "w") do io
|
||||
# JSON3.pretty(io, terminalstate)
|
||||
# end
|
||||
|
||||
backpropagate(leafNode, simTrajectoryReward)
|
||||
end
|
||||
|
||||
# stop if the early stop condition is met
|
||||
if typeof(earlystop) <: Function && earlystop(node.state)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
bestNextState = selectBestNextNode(root)
|
||||
besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
end
|
||||
|
||||
|
||||
# function runMCTS(
|
||||
# initialstate::T,
|
||||
# transition::Function,
|
||||
# transitionargs::NamedTuple,
|
||||
# ;
|
||||
# totalsample::Integer=3,
|
||||
# maxdepth::Integer=3,
|
||||
# maxiterations::Integer=10,
|
||||
# explorationweight::Number=1.0,
|
||||
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
|
||||
|
||||
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
|
||||
|
||||
# for nth in 1:maxiterations
|
||||
# node = root
|
||||
# node.visits += 1
|
||||
|
||||
# while !isleaf(node)
|
||||
# node = UCTselect(node, explorationweight)
|
||||
# end
|
||||
# if node.isterminal
|
||||
# # MCTS arrive at the leaf node that is also a terminal state,
|
||||
# # do nothing then go directly to backpropagation. It means the end of this iteration
|
||||
# backpropagate(leafNode, node.reward)
|
||||
# else
|
||||
# expand(node, transition, transitionargs;
|
||||
# totalsample=totalsample)
|
||||
# leafNode = selectChildNode(node)
|
||||
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
|
||||
# maxdepth=maxdepth, totalsample=totalsample)
|
||||
# # if terminalstate !== nothing #XXX not sure why I need this
|
||||
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
# # end
|
||||
|
||||
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||
# # open("trajectory.json", "w") do io
|
||||
# # JSON3.pretty(io, terminalstate)
|
||||
# # end
|
||||
|
||||
# backpropagate(leafNode, simTrajectoryReward)
|
||||
# end
|
||||
# end
|
||||
|
||||
# bestNextState = selectBestNextNode(root)
|
||||
# besttrajectory = selectBestTrajectoryNode(root)
|
||||
|
||||
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module interface
|
||||
856
src/mcts.jl
856
src/mcts.jl
@@ -1,429 +1,429 @@
|
||||
module mcts
|
||||
|
||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, makeNewState
|
||||
using Base.Threads
|
||||
using GeneralUtils
|
||||
|
||||
using ..type
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = -1
|
||||
nodekey = nothing
|
||||
|
||||
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
||||
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
||||
|
||||
if stateValueSum != 0
|
||||
for (k, childnode) in node.children
|
||||
potential = childnode.statevalue / childnode.visits
|
||||
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childnode.nodekey
|
||||
end
|
||||
end
|
||||
else
|
||||
for (k, childnode) in node.children
|
||||
potential = childnode.progressvalue + childnode.reward
|
||||
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childnode.nodekey
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return node.children[nodekey]
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
|
||||
while !isleaf(node)
|
||||
node = selectBestNextNode(node)
|
||||
end
|
||||
|
||||
return node
|
||||
end
|
||||
|
||||
|
||||
""" Backpropagate reward along the simulation chain
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
leaf node of a search tree
|
||||
- `simTrajectoryReward::T`
|
||||
total reward from trajectory simulation
|
||||
- `discountRewardCoeff::AbstractFloat`
|
||||
A discount reward coefficient to reduce future reward. The futher in the future the lower
|
||||
reward it is now.
|
||||
|
||||
# Return
|
||||
- `None`
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||
while !isroot(node)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||
node = node.parent
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Determine whether a node is a leaf node of a search tree.
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
a search tree node
|
||||
# Return
|
||||
- `result::Bool`
|
||||
true if it is a leaf node, false otherwise.
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> using Revise
|
||||
julia> using YiemAgent, DataStructures
|
||||
julia> initialState = Dict{Symbol, Any}(
|
||||
:customerinfo=> Dict{Symbol, Any}(),
|
||||
:storeinfo=> Dict{Symbol, Any}(),
|
||||
|
||||
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||
:question=> "How are you?",
|
||||
)
|
||||
)
|
||||
julia> statetype = typeof(initialState)
|
||||
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
|
||||
julia> YiemAgent.isleaf(root)
|
||||
true
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docs
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
|
||||
""" Determine wheter a given node is a root node
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `isrootnode::Bool`
|
||||
true if the given node is root node, false otherwise
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
|
||||
|
||||
|
||||
|
||||
""" Select child node based on the highest statevalue
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = -1
|
||||
nodekey = nothing
|
||||
|
||||
# loop thought node children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
potential = childNode.progressvalue + childNode.reward
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childNode.nodekey
|
||||
end
|
||||
end
|
||||
|
||||
return node.children[nodekey]
|
||||
end
|
||||
|
||||
|
||||
""" Expand selected node.
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
MCTS node
|
||||
- `transition::Function`
|
||||
A function that handles state transition.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for transition()
|
||||
- `totalsample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Return
|
||||
- None
|
||||
|
||||
# Signature
|
||||
"""
|
||||
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
# totalsample::Integer=3)
|
||||
|
||||
# # not use Any[] because I want to preserve result order
|
||||
# results = Vector{Any}(undef, totalsample)
|
||||
|
||||
# @sync for i in 1:totalsample
|
||||
# @spawn begin
|
||||
# result = transition(deepcopy(node.state), deepcopy(transitionargs))
|
||||
# results[i] = result
|
||||
# end
|
||||
# end
|
||||
|
||||
# for result in results
|
||||
# newNodeKey::AbstractString = result[:newNodeKey]
|
||||
# newstate::AbstractDict = result[:newstate]
|
||||
# progressvalue::Integer = result[:progressvalue]
|
||||
|
||||
# """
|
||||
# [] newNodeKey ∉ keys(node.children).
|
||||
# New state may have semantic vector close enought to
|
||||
# one of existing child state. Which can be assume that they are the same state
|
||||
# semantically-wise i.e. De javu. This could be used to recall lessons for this
|
||||
# similar situation to improve decisionMaker and evaluator.
|
||||
# """
|
||||
# if newNodeKey ∉ keys(node.children)
|
||||
# node.children[newNodeKey] =
|
||||
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
# newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
totalsample::Integer=3)
|
||||
|
||||
nthSample = 0
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= totalsample
|
||||
result = transition(node.state, transitionargs)
|
||||
newNodeKey::AbstractString = result[:newNodeKey]
|
||||
newstate::AbstractDict = result[:newstate]
|
||||
progressvalue::Integer = result[:progressvalue]
|
||||
|
||||
"""
|
||||
[] newNodeKey ∉ keys(node.children).
|
||||
New state may have semantic vector close enought to
|
||||
one of existing child state. Which can be assume that they are the same state
|
||||
semantically-wise i.e. De javu. This could be used to recall lessons for this
|
||||
similar situation to improve decisionMaker and evaluator.
|
||||
"""
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] =
|
||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||
end
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Simulate interactions between agent and environment
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node that will be a simulation starting point.
|
||||
- `transition::Function`
|
||||
A user function that handles how state transition.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for everything the user will use within transition().
|
||||
- `maxdepth::Integer`
|
||||
maximum depth level MCTS goes vertically.
|
||||
- totalsample::Integer
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Return
|
||||
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxdepth::Integer=3, totalsample::Integer=3
|
||||
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
terminalstate = nothing
|
||||
|
||||
for depth in 1:maxdepth
|
||||
simTrajectoryReward += node.reward
|
||||
if node.isterminal
|
||||
terminalstate = node.state
|
||||
break
|
||||
else
|
||||
expand(node, transition, transitionargs;
|
||||
totalsample=totalsample)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
|
||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
|
||||
# Return
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
reward::T3, isterminal::Bool
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
|
||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
currentstate_nextIndice =
|
||||
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
|
||||
_, thoughtDict_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
|
||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
if thoughtDict_latestThoughtIndice == -1
|
||||
(:thought, :action)
|
||||
else
|
||||
(
|
||||
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||
)
|
||||
end
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
newstate = deepcopy(currentstate)
|
||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
thoughtDict[thoughtDict_latestThoughtKey]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
module mcts
|
||||
|
||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, makeNewState
|
||||
using Base.Threads
|
||||
using GeneralUtils
|
||||
|
||||
using ..type
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectBestNextNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = -1
|
||||
nodekey = nothing
|
||||
|
||||
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
||||
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
||||
|
||||
if stateValueSum != 0
|
||||
for (k, childnode) in node.children
|
||||
potential = childnode.statevalue / childnode.visits
|
||||
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childnode.nodekey
|
||||
end
|
||||
end
|
||||
else
|
||||
for (k, childnode) in node.children
|
||||
potential = childnode.progressvalue + childnode.reward
|
||||
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childnode.nodekey
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return node.children[nodekey]
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
|
||||
while !isleaf(node)
|
||||
node = selectBestNextNode(node)
|
||||
end
|
||||
|
||||
return node
|
||||
end
|
||||
|
||||
|
||||
""" Backpropagate reward along the simulation chain
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
leaf node of a search tree
|
||||
- `simTrajectoryReward::T`
|
||||
total reward from trajectory simulation
|
||||
- `discountRewardCoeff::AbstractFloat`
|
||||
A discount reward coefficient to reduce future reward. The futher in the future the lower
|
||||
reward it is now.
|
||||
|
||||
# Return
|
||||
- `None`
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||
while !isroot(node)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||
node = node.parent
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Determine whether a node is a leaf node of a search tree.
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
a search tree node
|
||||
# Return
|
||||
- `result::Bool`
|
||||
true if it is a leaf node, false otherwise.
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> using Revise
|
||||
julia> using YiemAgent, DataStructures
|
||||
julia> initialState = Dict{Symbol, Any}(
|
||||
:customerinfo=> Dict{Symbol, Any}(),
|
||||
:storeinfo=> Dict{Symbol, Any}(),
|
||||
|
||||
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||
:question=> "How are you?",
|
||||
)
|
||||
)
|
||||
julia> statetype = typeof(initialState)
|
||||
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
|
||||
julia> YiemAgent.isleaf(root)
|
||||
true
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docs
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
|
||||
""" Determine wheter a given node is a root node
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `isrootnode::Bool`
|
||||
true if the given node is root node, false otherwise
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
|
||||
|
||||
|
||||
|
||||
""" Select child node based on the highest statevalue
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = -1
|
||||
nodekey = nothing
|
||||
|
||||
# loop thought node children dictionary to find the highest progress value
|
||||
for (k, childNode) in node.children
|
||||
potential = childNode.progressvalue + childNode.reward
|
||||
if potential > highestProgressValue
|
||||
highestProgressValue = potential
|
||||
nodekey = childNode.nodekey
|
||||
end
|
||||
end
|
||||
|
||||
return node.children[nodekey]
|
||||
end
|
||||
|
||||
|
||||
""" Expand selected node.
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
MCTS node
|
||||
- `transition::Function`
|
||||
A function that handles state transition.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for transition()
|
||||
- `totalsample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Return
|
||||
- None
|
||||
|
||||
# Signature
|
||||
"""
|
||||
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
# totalsample::Integer=3)
|
||||
|
||||
# # not use Any[] because I want to preserve result order
|
||||
# results = Vector{Any}(undef, totalsample)
|
||||
|
||||
# @sync for i in 1:totalsample
|
||||
# @spawn begin
|
||||
# result = transition(deepcopy(node.state), deepcopy(transitionargs))
|
||||
# results[i] = result
|
||||
# end
|
||||
# end
|
||||
|
||||
# for result in results
|
||||
# newNodeKey::AbstractString = result[:newNodeKey]
|
||||
# newstate::AbstractDict = result[:newstate]
|
||||
# progressvalue::Integer = result[:progressvalue]
|
||||
|
||||
# """
|
||||
# [] newNodeKey ∉ keys(node.children).
|
||||
# New state may have semantic vector close enought to
|
||||
# one of existing child state. Which can be assume that they are the same state
|
||||
# semantically-wise i.e. De javu. This could be used to recall lessons for this
|
||||
# similar situation to improve decisionMaker and evaluator.
|
||||
# """
|
||||
# if newNodeKey ∉ keys(node.children)
|
||||
# node.children[newNodeKey] =
|
||||
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
# newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||
# end
|
||||
# end
|
||||
# end
|
||||
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
totalsample::Integer=3)
|
||||
|
||||
nthSample = 0
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= totalsample
|
||||
result = transition(node.state, transitionargs)
|
||||
newNodeKey::AbstractString = result[:newNodeKey]
|
||||
newstate::AbstractDict = result[:newstate]
|
||||
progressvalue::Integer = result[:progressvalue]
|
||||
|
||||
"""
|
||||
[] newNodeKey ∉ keys(node.children).
|
||||
New state may have semantic vector close enought to
|
||||
one of existing child state. Which can be assume that they are the same state
|
||||
semantically-wise i.e. De javu. This could be used to recall lessons for this
|
||||
similar situation to improve decisionMaker and evaluator.
|
||||
"""
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] =
|
||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||
end
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
""" Simulate interactions between agent and environment
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
node that will be a simulation starting point.
|
||||
- `transition::Function`
|
||||
A user function that handles how state transition.
|
||||
- `transitionargs::NamedTuple`
|
||||
Arguments for everything the user will use within transition().
|
||||
- `maxdepth::Integer`
|
||||
maximum depth level MCTS goes vertically.
|
||||
- totalsample::Integer
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Return
|
||||
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxdepth::Integer=3, totalsample::Integer=3
|
||||
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
terminalstate = nothing
|
||||
|
||||
for depth in 1:maxdepth
|
||||
simTrajectoryReward += node.reward
|
||||
if node.isterminal
|
||||
terminalstate = node.state
|
||||
break
|
||||
else
|
||||
expand(node, transition, transitionargs;
|
||||
totalsample=totalsample)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
|
||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
|
||||
# Return
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
reward::T3, isterminal::Bool
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
|
||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
currentstate_nextIndice =
|
||||
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
|
||||
_, thoughtDict_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
|
||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
if thoughtDict_latestThoughtIndice == -1
|
||||
(:thought, :action)
|
||||
else
|
||||
(
|
||||
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||
)
|
||||
end
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
newstate = deepcopy(currentstate)
|
||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
thoughtDict[thoughtDict_latestThoughtKey]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module mcts
|
||||
230
src/type.jl
230
src/type.jl
@@ -1,116 +1,116 @@
|
||||
module type
|
||||
|
||||
export MCTSNode
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
""" a node for MCTS search tree
|
||||
|
||||
# Arguments
|
||||
- `state::T`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
- `visits::Integer `
|
||||
number of time the game visits this state
|
||||
- `stateValue::Float64`
|
||||
state value
|
||||
- `children::Dict{T, MCTSNode}`
|
||||
children node
|
||||
|
||||
# Return
|
||||
- `nothing`
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> state = Dict(
|
||||
:info=> Dict(), # keyword info
|
||||
:thoughtHistory=> Dict(
|
||||
:question=> _,
|
||||
:thought_1=> _,
|
||||
:action_1=> _,
|
||||
:observation_1=> _,
|
||||
:thought_2=> _,
|
||||
...
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
|
||||
# Signature
|
||||
"""
|
||||
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||
nodekey::T2
|
||||
state::T1
|
||||
visits::Integer
|
||||
progressvalue::Number # estimate value by LLM's reasoning
|
||||
statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node)
|
||||
reward::Number # this node's immediate reward
|
||||
isterminal::Bool
|
||||
parent::Union{MCTSNode, Nothing}
|
||||
children::Dict{String, MCTSNode}
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
module type
|
||||
|
||||
export MCTSNode
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
""" a node for MCTS search tree
|
||||
|
||||
# Arguments
|
||||
- `state::T`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
- `visits::Integer `
|
||||
number of time the game visits this state
|
||||
- `stateValue::Float64`
|
||||
state value
|
||||
- `children::Dict{T, MCTSNode}`
|
||||
children node
|
||||
|
||||
# Return
|
||||
- `nothing`
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> state = Dict(
|
||||
:info=> Dict(), # keyword info
|
||||
:thoughtHistory=> Dict(
|
||||
:question=> _,
|
||||
:thought_1=> _,
|
||||
:action_1=> _,
|
||||
:observation_1=> _,
|
||||
:thought_2=> _,
|
||||
...
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
|
||||
# Signature
|
||||
"""
|
||||
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||
nodekey::T2
|
||||
state::T1
|
||||
visits::Integer
|
||||
progressvalue::Number # estimate value by LLM's reasoning
|
||||
statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node)
|
||||
reward::Number # this node's immediate reward
|
||||
isterminal::Bool
|
||||
parent::Union{MCTSNode, Nothing}
|
||||
children::Dict{String, MCTSNode}
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module type
|
||||
276
src/util.jl
276
src/util.jl
@@ -1,139 +1,139 @@
|
||||
module util
|
||||
|
||||
export UCTselect
|
||||
|
||||
using ..type
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
""" Select a node based on UCT score
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
mcts node
|
||||
- `w::T`
|
||||
exploration weight. Value is usually between 1 to 2.
|
||||
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
|
||||
Value 2.0 makes MCTS aggressively search the tree.
|
||||
# Return
|
||||
- `selectedNode::MCTSNode`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||
maxUCT = -Inf
|
||||
selectedNode = nothing
|
||||
|
||||
for (childState, childNode) in node.children
|
||||
UCTvalue =
|
||||
if childNode.visits != 0
|
||||
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
|
||||
childNode.statevalue + weightedterm
|
||||
else # node.visits == 0 makes sqrt() in explore term error
|
||||
childNode.progressvalue # exploit term
|
||||
end
|
||||
|
||||
if UCTvalue > maxUCT
|
||||
maxUCT = UCTvalue
|
||||
selectedNode = childNode
|
||||
end
|
||||
end
|
||||
|
||||
return selectedNode
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
module util
|
||||
|
||||
export UCTselect
|
||||
|
||||
using ..type
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
""" Select a node based on UCT score
|
||||
|
||||
# Arguments
|
||||
- `node::MCTSNode`
|
||||
mcts node
|
||||
- `w::T`
|
||||
exploration weight. Value is usually between 1 to 2.
|
||||
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
|
||||
Value 2.0 makes MCTS aggressively search the tree.
|
||||
# Return
|
||||
- `selectedNode::MCTSNode`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||
maxUCT = -Inf
|
||||
selectedNode = nothing
|
||||
|
||||
for (childState, childNode) in node.children
|
||||
UCTvalue =
|
||||
if childNode.visits != 0
|
||||
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
|
||||
childNode.statevalue + weightedterm
|
||||
else # node.visits == 0 makes sqrt() in explore term error
|
||||
childNode.progressvalue # exploit term
|
||||
end
|
||||
|
||||
if UCTvalue > maxUCT
|
||||
maxUCT = UCTvalue
|
||||
selectedNode = childNode
|
||||
end
|
||||
end
|
||||
|
||||
return selectedNode
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module util
|
||||
Reference in New Issue
Block a user