This commit is contained in:
narawat lamaiin
2024-05-31 11:47:51 +07:00
parent 3613f1d2bd
commit 452262d3d6
6 changed files with 891 additions and 1 deletions

View File

@@ -1,5 +1,28 @@
module LLMMCTS
greet() = print("Hello World!")
# export agent
""" Order by dependencies of each file. The 1st included file must not depend on any other
files and each file can only depend on the file included before it.
"""
include("type.jl")
using .type
include("util.jl")
using .util
include("mcts.jl")
using .mcts
include("interface.jl")
using .interface
# ---------------------------------------------- 100 --------------------------------------------- #
end # module LLMMCTS

179
src/interface.jl Normal file
View File

@@ -0,0 +1,179 @@
module interface
export runMCTS
using ..type, ..mcts
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `a::agent`
one of Yiem's agents
- `initial state`
initial state
- `decisionMaker::Function`
decide what action to take
- `evaluator::Function`
assess the value of the state
- `reflector::Function`
generate lesson from trajectory and reward
- `isterminal::Function`
determine whether a given state is a terminal state
- `n::Integer`
how many times action will be sampled from decisionMaker
- `w::Float64`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
Value 2.0 makes MCTS aggressively search the tree
# Return
- `plan::Vector{Dict}`
best plan
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] return best action
# Signature
"""
function runMCTS(
workDict::Dict{Symbol, Any},
initialState,
decisionMaker::Function,
evaluator::Function,
reflector::Function,
transition::Function,
;
totalsample::Integer=3,
maxDepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
)
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward)
else
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator,
reflector; maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing #XXX not sure why I need this
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
end
bestNextState = selectBestNextState(root)
besttrajectory = selectBestTrajectory(root)
return (bestNextState.state, besttrajectory.state)
end
end # module interface

432
src/mcts.jl Normal file
View File

@@ -0,0 +1,432 @@
module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
expand, mctstransition
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0
for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
else
for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
end
return node.children[nodekey]
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
while !isleaf(node)
node = selectBestNextState(node)
end
return node
end
""" Backpropagate reward along the simulation chain
# Arguments
- `node::MCTSNode`
leaf node of a search tree
- `simTrajectoryReward::T`
total reward from trajectory simulation
# Return
- `No return`
# Example
```jldoctest
julia>
```
# Signature
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent
end
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?",
)
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Example
```jldoctest
julia>
```
# Signature
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing
# loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward
if childNode.reward > 0 #XXX for testing. remove when done.
println("")
end
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childNode.nodekey
end
end
return node.children[nodekey]
end
""" Expand selected node
# Arguments
- `a::T1`
One of YiemAgent's agent
- `node::MCTSNode`
MCTS node
- `state::T2`
a state of a game. Can be a Dict or something else.
- `decisionMaker::Function`
a function that output Thought and Action
- `evaluator::Function`
a function that output trajectory progress score
# Return
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] try loop should limit to 3 times. if not succeed, skip
[] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise.
[x] store feedback -> state -> agent.
# Signature
"""
function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; totalsample::Integer=3
) where {T1<:AbstractDict}
nthSample = 0
while true
nthSample += 1
if nthSample <= totalsample
thoughtDict = decisionMaker(a, node.state)
println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict)
newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict)
stateevaluation, progressvalue = evaluator(workDict, newstate)
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
if newNodeKey keys(node.children)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}())
end
else
break
end
end
end
""" Get a new state
# Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state
- `thoughtDict::T3`
contain Thought, Action, Observation
- `isterminal::Function`
a function to determine terminal state
# Return
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought_1=> "The customer wants to buy a bottle of wine.",
:action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:observation_1 => ""
)
```
# TODO
- [] add other actions
- [WORKING] add embedding of newstate and store in newstate[:embedding]
# Signature
"""
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
# actionname = thoughtDict[:action][:name]
# actioninput = thoughtDict[:action][:input]
# # map action and input() to llm function
# response, select, reward, isterminal =
# if actionname == "chatbox"
# # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
# # so that other simulation start from this same node is not contaminated with actioninput
# virtualWineUserChatbox(workDict, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
# elseif actionname == "winestock"
# winestock(a, actioninput)
# elseif actionname == "recommendbox"
# virtualWineUserRecommendbox(workDict, actioninput)
# else
# error("undefined LLM function. Requesting $actionname")
# end
# newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal)
# if actionname == "chatbox"
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
# end
return (newNodeKey, newstate)
end
end # module mcts

116
src/type.jl Normal file
View File

@@ -0,0 +1,116 @@
module type
export MCTSNode
# ---------------------------------------------- 100 --------------------------------------------- #
""" a node for MCTS search tree
# Arguments
- `state::T`
a state of a game. Can be a Dict or something else.
- `visits::Integer `
number of time the game visits this state
- `stateValue::Float64`
state value
- `children::Dict{T, MCTSNode}`
children node
# Return
- `nothing`
# Example
```jldoctest
julia> state = Dict(
:info=> Dict(), # keyword info
:thoughtHistory=> Dict(
:question=> _,
:thought_1=> _,
:action_1=> _,
:observation_1=> _,
:thought_2=> _,
...
)
)
```
# TODO
[] update docstring
# Signature
"""
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
nodekey::T2
state::T1
visits::Integer
progressvalue::Number # estimate value by LLM's reasoning
statevalue::Number # store discounted commulative reward (gather from its child node)
reward::Number # this node's own reward
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
end # module type

139
src/util.jl Normal file
View File

@@ -0,0 +1,139 @@
module util
export UCTselect
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
""" Select a node based on UCT score
# Arguments
- `node::MCTSNode`
mcts node
- `w::T`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
Value 2.0 makes MCTS aggressively search the tree.
# Return
- `selectedNode::MCTSNode`
# Example
```jldoctest
julia>
```
# Signature
"""
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
maxUCT = -Inf
selectedNode = nothing
for (childState, childNode) in node.children
UCTvalue =
if childNode.visits != 0
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
childNode.statevalue + weightedterm
else # node.visits == 0 makes sqrt() in explore term error
childNode.progressvalue # exploit term
end
if UCTvalue > maxUCT
maxUCT = UCTvalue
selectedNode = childNode
end
end
return selectedNode
end
end # module util