This commit is contained in:
narawat lamaiin
2024-12-09 20:28:02 +07:00
parent 3338085567
commit cae94e5690
15 changed files with 2942 additions and 2942 deletions

View File

@@ -1,38 +1,38 @@
module LLMMCTS
# export agent  # NOTE(review): nothing is re-exported yet; confirm the intended public API
""" File inclusion order follows dependencies. The 1st included file must not depend on any
other files and each file can only depend on the files included before it
(type.jl -> util.jl -> mcts.jl -> interface.jl).
"""
include("type.jl")
using .type
include("util.jl")
using .util
include("mcts.jl")
using .mcts
include("interface.jl")
using .interface
# ---------------------------------------------- 100 --------------------------------------------- #
""" version 0.0.2
Todo:
- []
Change from version: 0.0.1
-
All features
"""
end # module LLMMCTS
module LLMMCTS
# export agent  # NOTE(review): nothing is re-exported yet; confirm the intended public API
""" File inclusion order follows dependencies. The 1st included file must not depend on any
other files and each file can only depend on the files included before it
(type.jl -> util.jl -> mcts.jl -> interface.jl).
"""
include("type.jl")
using .type
include("util.jl")
using .util
include("mcts.jl")
using .mcts
include("interface.jl")
using .interface
# ---------------------------------------------- 100 --------------------------------------------- #
""" version 0.0.2
Todo:
- []
Change from version: 0.0.1
-
All features
"""
end # module LLMMCTS

View File

@@ -1,228 +1,228 @@
module interface
export runMCTS
using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `initialstate::T`
initial state
- `transition::Function`
a function that define how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `totalsample::Integer`
a number of child state MCTS sample at each node during expansion phase
- `maxdepth::Integer`
a number of levels MCTS goes during simulation phase
- `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
- `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state.
# Return
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
the best next state and the best final state
# Example
Refers to SQLLLM package
# Signature
"""
function runMCTS(
initialstate::T,
transition::Function,
transitionargs::NamedTuple,
;
totalsample::Integer=3,
maxdepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward)
else
expand(node, transition, transitionargs;
totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample)
# if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
# stop if the early stop condition is met
if typeof(earlystop) <: Function && earlystop(node.state)
break
end
end
bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root)
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
# function runMCTS(
# initialstate::T,
# transition::Function,
# transitionargs::NamedTuple,
# ;
# totalsample::Integer=3,
# maxdepth::Integer=3,
# maxiterations::Integer=10,
# explorationweight::Number=1.0,
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
# for nth in 1:maxiterations
# node = root
# node.visits += 1
# while !isleaf(node)
# node = UCTselect(node, explorationweight)
# end
# if node.isterminal
# # MCTS arrive at the leaf node that is also a terminal state,
# # do nothing then go directly to backpropagation. It means the end of this iteration
# backpropagate(leafNode, node.reward)
# else
# expand(node, transition, transitionargs;
# totalsample=totalsample)
# leafNode = selectChildNode(node)
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
# maxdepth=maxdepth, totalsample=totalsample)
# # if terminalstate !== nothing #XXX not sure why I need this
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# # end
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# # open("trajectory.json", "w") do io
# # JSON3.pretty(io, terminalstate)
# # end
# backpropagate(leafNode, simTrajectoryReward)
# end
# end
# bestNextState = selectBestNextNode(root)
# besttrajectory = selectBestTrajectoryNode(root)
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
# end
module interface
export runMCTS
using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `initialstate::T`
initial state
- `transition::Function`
a function that define how the state transitions
- `transitionargs::NamedTuple`
arguments for transition function
# Keyword Arguments
- `totalsample::Integer`
a number of child state MCTS sample at each node during expansion phase
- `maxdepth::Integer`
a number of levels MCTS goes during simulation phase
- `maxiterations::Integer`
a number of iteration MCTS goes thru expansion -> simulation -> backpropagation cycle
- `explorationweight::Number`
exploration weight controls how much MCTS should explore new state instead of exploit
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state.
# Return
- `NamedTuple{(:bestNextState, :bestFinalState), Tuple{T, T}}`
the best next state and the best final state
# Example
Refers to SQLLLM package
# Signature
"""
function runMCTS(
initialstate::T,
transition::Function,
transitionargs::NamedTuple,
;
totalsample::Integer=3,
maxdepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing
)::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward)
else
expand(node, transition, transitionargs;
totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
maxdepth=maxdepth, totalsample=totalsample)
# if terminalstate !== nothing #XXX not sure why I need this
# terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
# stop if the early stop condition is met
if typeof(earlystop) <: Function && earlystop(node.state)
break
end
end
bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root)
return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
# function runMCTS(
# initialstate::T,
# transition::Function,
# transitionargs::NamedTuple,
# ;
# totalsample::Integer=3,
# maxdepth::Integer=3,
# maxiterations::Integer=10,
# explorationweight::Number=1.0,
# )::NamedTuple{(:bestNextState, :bestFinalState),Tuple{T,T}} where {T<:Any}
# root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}())
# for nth in 1:maxiterations
# node = root
# node.visits += 1
# while !isleaf(node)
# node = UCTselect(node, explorationweight)
# end
# if node.isterminal
# # MCTS arrive at the leaf node that is also a terminal state,
# # do nothing then go directly to backpropagation. It means the end of this iteration
# backpropagate(leafNode, node.reward)
# else
# expand(node, transition, transitionargs;
# totalsample=totalsample)
# leafNode = selectChildNode(node)
# simTrajectoryReward, terminalstate = simulate(leafNode, transition, transitionargs;
# maxdepth=maxdepth, totalsample=totalsample)
# # if terminalstate !== nothing #XXX not sure why I need this
# # terminalstate[:totalTrajectoryReward] = simTrajectoryReward
# # end
# #[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# # open("trajectory.json", "w") do io
# # JSON3.pretty(io, terminalstate)
# # end
# backpropagate(leafNode, simTrajectoryReward)
# end
# end
# bestNextState = selectBestNextNode(root)
# besttrajectory = selectBestTrajectoryNode(root)
# return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
# end
end # module interface

View File

@@ -1,429 +1,429 @@
module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using Base.Threads
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1
nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0
for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
else
for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
end
return node.children[nodekey]
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
while !isleaf(node)
node = selectBestNextNode(node)
end
return node
end
""" Backpropagate reward along the simulation chain
# Arguments
- `node::MCTSNode`
leaf node of a search tree
- `simTrajectoryReward::T`
total reward from trajectory simulation
- `discountRewardCoeff::AbstractFloat`
A discount reward coefficient to reduce future reward. The futher in the future the lower
reward it is now.
# Return
- `None`
# Signature
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent
end
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?",
)
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Signature
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1
nodekey = nothing
# loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childNode.nodekey
end
end
return node.children[nodekey]
end
""" Expand selected node.
# Arguments
- `node::MCTSNode`
MCTS node
- `transition::Function`
A function that handles state transition.
- `transitionargs::NamedTuple`
Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- None
# Signature
"""
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3)
# # not use Any[] because I want to preserve result order
# results = Vector{Any}(undef, totalsample)
# @sync for i in 1:totalsample
# @spawn begin
# result = transition(deepcopy(node.state), deepcopy(transitionargs))
# results[i] = result
# end
# end
# for result in results
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}())
# end
# end
# end
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
    totalsample::Integer=3)
    # Sample up to `totalsample` candidate child states from the current node.
    # (Replaces the original `while true` loop with a manual counter.)
    for _ in 1:totalsample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]
        # [] newNodeKey ∉ keys(node.children):
        # a new state may have a semantic vector close enough to one of the existing
        # child states, in which case they can be assumed to be the same state
        # semantically (i.e. deja vu). This could be used to recall lessons for a
        # similar situation to improve decisionMaker and evaluator.
        # BUGFIX: the membership operator `∉` had been dropped here
        # (`if newNodeKey keys(node.children)` is a syntax error); restored to match
        # the reference implementation above.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                    newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end
end
""" Simulate interactions between agent and environment
# Arguments
- `node::MCTSNode`
node that will be a simulation starting point.
- `transition::Function`
A user function that handles how state transition.
- `transitionargs::NamedTuple`
Arguments for everything the user will use within transition().
- `maxdepth::Integer`
maximum depth level MCTS goes vertically.
- totalsample::Integer
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
# Signature
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0
terminalstate = nothing
for depth in 1:maxdepth
simTrajectoryReward += node.reward
if node.isterminal
terminalstate = node.state
break
else
expand(node, transition, transitionargs;
totalsample=totalsample)
node = selectChildNode(node)
end
end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
end
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [x] implement the function
# Signature
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice")
_, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1
(:thought, :action)
else
(
Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"),
)
end
# add Thought, action, observation to thoughtHistory
newstate = deepcopy(currentstate)
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
end
module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
using Base.Threads
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1
nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0
for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
else
for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
end
return node.children[nodekey]
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
while !isleaf(node)
node = selectBestNextNode(node)
end
return node
end
""" Backpropagate reward along the simulation chain
# Arguments
- `node::MCTSNode`
leaf node of a search tree
- `simTrajectoryReward::T`
total reward from trajectory simulation
- `discountRewardCoeff::AbstractFloat`
A discount reward coefficient to reduce future reward. The futher in the future the lower
reward it is now.
# Return
- `None`
# Signature
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent
end
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?",
)
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Signature
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = -1
nodekey = nothing
# loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childNode.nodekey
end
end
return node.children[nodekey]
end
""" Expand selected node.
# Arguments
- `node::MCTSNode`
MCTS node
- `transition::Function`
A function that handles state transition.
- `transitionargs::NamedTuple`
Arguments for transition()
- `totalsample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- None
# Signature
"""
# function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3)
# # not use Any[] because I want to preserve result order
# results = Vector{Any}(undef, totalsample)
# @sync for i in 1:totalsample
# @spawn begin
# result = transition(deepcopy(node.state), deepcopy(transitionargs))
# results[i] = result
# end
# end
# for result in results
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}())
# end
# end
# end
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
    totalsample::Integer=3)
    # Sample up to `totalsample` candidate child states from the current node.
    # (Replaces the original `while true` loop with a manual counter.)
    for _ in 1:totalsample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]
        # [] newNodeKey ∉ keys(node.children):
        # a new state may have a semantic vector close enough to one of the existing
        # child states, in which case they can be assumed to be the same state
        # semantically (i.e. deja vu). This could be used to recall lessons for a
        # similar situation to improve decisionMaker and evaluator.
        # BUGFIX: the membership operator `∉` had been dropped here
        # (`if newNodeKey keys(node.children)` is a syntax error); restored to match
        # the reference implementation above.
        if newNodeKey ∉ keys(node.children)
            node.children[newNodeKey] =
                MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                    newstate[:isterminal], node, Dict{String, MCTSNode}())
        end
    end
end
""" Simulate interactions between agent and environment
# Arguments
- `node::MCTSNode`
node that will be a simulation starting point.
- `transition::Function`
A user function that handles how state transition.
- `transitionargs::NamedTuple`
Arguments for everything the user will use within transition().
- `maxdepth::Integer`
maximum depth level MCTS goes vertically.
- totalsample::Integer
Total number to sample from the current node (i.e. expand new node horizontally)
# Return
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
# Signature
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxdepth::Integer=3, totalsample::Integer=3
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0
terminalstate = nothing
for depth in 1:maxdepth
simTrajectoryReward += node.reward
if node.isterminal
terminalstate = node.state
break
else
expand(node, transition, transitionargs;
totalsample=totalsample)
node = selectChildNode(node)
end
end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
end
"""
# Arguments
# Return
# Example
```jldoctest
julia>
```
# TODO
- [] update docstring
- [x] implement the function
# Signature
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice")
_, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1
(:thought, :action)
else
(
Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"),
)
end
# add Thought, action, observation to thoughtHistory
newstate = deepcopy(currentstate)
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
end
end # module mcts

View File

@@ -1,116 +1,116 @@
module type
export MCTSNode
# ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree
# Arguments
- `state::T`
a state of a game. Can be a Dict or something else.
- `visits::Integer `
number of time the game visits this state
- `stateValue::Float64`
state value
- `children::Dict{T, MCTSNode}`
children node
# Return
- `nothing`
# Example
```jldoctest
julia> state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
```
# TODO
[] update docstring
# Signature
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2
state::T1
visits::Integer
progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node)
reward::Number # this node's immediate reward
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
module type
export MCTSNode
# ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree
# Arguments
- `state::T`
a state of a game. Can be a Dict or something else.
- `visits::Integer `
number of time the game visits this state
- `stateValue::Float64`
state value
- `children::Dict{T, MCTSNode}`
children node
# Return
- `nothing`
# Example
```jldoctest
julia> state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
```
# TODO
[] update docstring
# Signature
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2
state::T1
visits::Integer
progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # current state value. store the node's immediate reward and all future discounted rewards (gather from its child node)
reward::Number # this node's immediate reward
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
end # module type

View File

@@ -1,139 +1,139 @@
module util
export UCTselect
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
""" Select a node based on UCT score
# Arguments
- `node::MCTSNode`
mcts node
- `w::T`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
Value 2.0 makes MCTS aggressively search the tree.
# Return
- `selectedNode::MCTSNode`
# Example
```jldoctest
julia>
```
# Signature
"""
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
maxUCT = -Inf
selectedNode = nothing
for (childState, childNode) in node.children
UCTvalue =
if childNode.visits != 0
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
childNode.statevalue + weightedterm
else # node.visits == 0 makes sqrt() in explore term error
childNode.progressvalue # exploit term
end
if UCTvalue > maxUCT
maxUCT = UCTvalue
selectedNode = childNode
end
end
return selectedNode
end
module util
export UCTselect
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
""" Select a node based on UCT score
# Arguments
- `node::MCTSNode`
mcts node
- `w::T`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
Value 2.0 makes MCTS aggressively search the tree.
# Return
- `selectedNode::MCTSNode`
# Example
```jldoctest
julia>
```
# Signature
"""
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
maxUCT = -Inf
selectedNode = nothing
for (childState, childNode) in node.children
UCTvalue =
if childNode.visits != 0
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
childNode.statevalue + weightedterm
else # node.visits == 0 makes sqrt() in explore term error
childNode.progressvalue # exploit term
end
if UCTvalue > maxUCT
maxUCT = UCTvalue
selectedNode = childNode
end
end
return selectedNode
end
end # module util