This commit is contained in:
narawat lamaiin
2024-05-31 11:47:51 +07:00
parent 3613f1d2bd
commit 452262d3d6
6 changed files with 891 additions and 1 deletions

432
src/mcts.jl Normal file
View File

@@ -0,0 +1,432 @@
module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
expand, mctstransition
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestNextState(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
stateValueSum = sum([v.statevalue for (k, v) in node.children])
if stateValueSum != 0
for (k, childnode) in node.children
potential = childnode.statevalue / childnode.visits
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
else
for (k, childnode) in node.children
potential = childnode.progressvalue + childnode.reward
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childnode.nodekey
end
end
end
return node.children[nodekey]
end
"""
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [] update docs
- [x] implement the function
# Signature
"""
function selectBestTrajectory(node::MCTSNode)::MCTSNode
while !isleaf(node)
node = selectBestNextState(node)
end
return node
end
""" Backpropagate reward along the simulation chain
# Arguments
- `node::MCTSNode`
leaf node of a search tree
- `simTrajectoryReward::T`
total reward from trajectory simulation
# Return
- `No return`
# Example
```jldoctest
julia>
```
# Signature
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
node = node.parent
end
end
""" Determine whether a node is a leaf node of a search tree.
# Arguments
- `node::MCTSNode`
a search tree node
# Return
- `result::Bool`
true if it is a leaf node, false otherwise.
# Example
```jldoctest
julia> using Revise
julia> using YiemAgent, DataStructures
julia> initialState = Dict{Symbol, Any}(
:customerinfo=> Dict{Symbol, Any}(),
:storeinfo=> Dict{Symbol, Any}(),
:thoughtHistory=> OrderedDict{Symbol, Any}(
:question=> "How are you?",
)
)
julia> statetype = typeof(initialState)
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
julia> YiemAgent.isleaf(root)
true
```
# TODO
[] update docs
# Signature
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Determine wheter a given node is a root node
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `isrootnode::Bool`
true if the given node is root node, false otherwise
# Example
```jldoctest
julia>
```
# Signature
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# Signature
"""
function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing
# loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children
potential = childNode.progressvalue + childNode.reward
if childNode.reward > 0 #XXX for testing. remove when done.
println("")
end
if potential > highestProgressValue
highestProgressValue = potential
nodekey = childNode.nodekey
end
end
return node.children[nodekey]
end
""" Expand selected node
# Arguments
- `a::T1`
One of YiemAgent's agent
- `node::MCTSNode`
MCTS node
- `state::T2`
a state of a game. Can be a Dict or something else.
- `decisionMaker::Function`
a function that output Thought and Action
- `evaluator::Function`
a function that output trajectory progress score
# Return
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] try loop should limit to 3 times. if not succeed, skip
[] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise.
[x] store feedback -> state -> agent.
# Signature
"""
function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; totalsample::Integer=3
) where {T1<:AbstractDict}
nthSample = 0
while true
nthSample += 1
if nthSample <= totalsample
thoughtDict = decisionMaker(a, node.state)
println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict)
newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict)
stateevaluation, progressvalue = evaluator(workDict, newstate)
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(a, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
if newNodeKey keys(node.children)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}())
end
else
break
end
end
end
""" Get a new state
# Arguments
- `a::T1`
one of YiemAgent's agent
- `state::T2`
current game state
- `thoughtDict::T3`
contain Thought, Action, Observation
- `isterminal::Function`
a function to determine terminal state
# Return
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought_1=> "The customer wants to buy a bottle of wine.",
:action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:observation_1 => ""
)
```
# TODO
- [] add other actions
- [WORKING] add embedding of newstate and store in newstate[:embedding]
# Signature
"""
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
# actionname = thoughtDict[:action][:name]
# actioninput = thoughtDict[:action][:input]
# # map action and input() to llm function
# response, select, reward, isterminal =
# if actionname == "chatbox"
# # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
# # so that other simulation start from this same node is not contaminated with actioninput
# virtualWineUserChatbox(workDict, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
# elseif actionname == "winestock"
# winestock(a, actioninput)
# elseif actionname == "recommendbox"
# virtualWineUserRecommendbox(workDict, actioninput)
# else
# error("undefined LLM function. Requesting $actionname")
# end
# newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal)
# if actionname == "chatbox"
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
# end
return (newNodeKey, newstate)
end
end # module mcts