update
This commit is contained in:
1
.vscode/settings.json
vendored
Normal file
1
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
@@ -1,5 +1,28 @@
|
|||||||
module LLMMCTS
|
module LLMMCTS
|
||||||
|
|
||||||
greet() = print("Hello World!")
|
# export agent
|
||||||
|
|
||||||
|
|
||||||
|
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
||||||
|
files and each file can only depend on the file included before it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
include("type.jl")
|
||||||
|
using .type
|
||||||
|
|
||||||
|
include("util.jl")
|
||||||
|
using .util
|
||||||
|
|
||||||
|
include("mcts.jl")
|
||||||
|
using .mcts
|
||||||
|
|
||||||
|
include("interface.jl")
|
||||||
|
using .interface
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end # module LLMMCTS
|
end # module LLMMCTS
|
||||||
|
|||||||
179
src/interface.jl
Normal file
179
src/interface.jl
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
module interface
|
||||||
|
|
||||||
|
export runMCTS
|
||||||
|
|
||||||
|
using ..type, ..mcts
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
""" Search the best action to take for a given state and task
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `a::agent`
|
||||||
|
one of Yiem's agents
|
||||||
|
- `initial state`
|
||||||
|
initial state
|
||||||
|
- `decisionMaker::Function`
|
||||||
|
decide what action to take
|
||||||
|
- `evaluator::Function`
|
||||||
|
assess the value of the state
|
||||||
|
- `reflector::Function`
|
||||||
|
generate lesson from trajectory and reward
|
||||||
|
- `isterminal::Function`
|
||||||
|
determine whether a given state is a terminal state
|
||||||
|
- `n::Integer`
|
||||||
|
how many times action will be sampled from decisionMaker
|
||||||
|
- `w::Float64`
|
||||||
|
exploration weight. Value is usually between 1 to 2.
|
||||||
|
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
|
||||||
|
Value 2.0 makes MCTS aggressively search the tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `plan::Vector{Dict}`
|
||||||
|
best plan
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
[] update docstring
|
||||||
|
[] return best action
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function runMCTS(
|
||||||
|
workDict::Dict{Symbol, Any},
|
||||||
|
initialState,
|
||||||
|
decisionMaker::Function,
|
||||||
|
evaluator::Function,
|
||||||
|
reflector::Function,
|
||||||
|
transition::Function,
|
||||||
|
;
|
||||||
|
totalsample::Integer=3,
|
||||||
|
maxDepth::Integer=3,
|
||||||
|
maxiterations::Integer=10,
|
||||||
|
explorationweight::Number=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
|
for nth in 1:maxiterations
|
||||||
|
node = root
|
||||||
|
node.visits += 1
|
||||||
|
|
||||||
|
while !isleaf(node)
|
||||||
|
node = UCTselect(node, explorationweight)
|
||||||
|
end
|
||||||
|
if node.isterminal
|
||||||
|
# MCTS arrive at the leaf node that is also a terminal state,
|
||||||
|
# do nothing then go directly to backpropagation
|
||||||
|
backpropagate(leafNode, node.reward)
|
||||||
|
else
|
||||||
|
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
|
||||||
|
leafNode = selectChildNode(node)
|
||||||
|
simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator,
|
||||||
|
reflector; maxDepth=maxDepth, totalsample=totalsample)
|
||||||
|
if terminalstate !== nothing #XXX not sure why I need this
|
||||||
|
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||||
|
end
|
||||||
|
|
||||||
|
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
|
||||||
|
# open("trajectory.json", "w") do io
|
||||||
|
# JSON3.pretty(io, terminalstate)
|
||||||
|
# end
|
||||||
|
|
||||||
|
backpropagate(leafNode, simTrajectoryReward)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
bestNextState = selectBestNextState(root)
|
||||||
|
besttrajectory = selectBestTrajectory(root)
|
||||||
|
|
||||||
|
return (bestNextState.state, besttrajectory.state)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end # module interface
|
||||||
432
src/mcts.jl
Normal file
432
src/mcts.jl
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
module mcts
|
||||||
|
|
||||||
|
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
|
||||||
|
expand, mctstransition
|
||||||
|
|
||||||
|
using ..type
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node of a search tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `childNode::MCTSNode`
|
||||||
|
the highest value child node
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
- [x] implement the function
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function selectBestNextState(node::MCTSNode)::MCTSNode
|
||||||
|
highestProgressValue = 0
|
||||||
|
nodekey = nothing
|
||||||
|
|
||||||
|
# if all childnode has statevalue == 0, use progressvalue + reward to select the best node
|
||||||
|
stateValueSum = sum([v.statevalue for (k, v) in node.children])
|
||||||
|
|
||||||
|
if stateValueSum != 0
|
||||||
|
for (k, childnode) in node.children
|
||||||
|
potential = childnode.statevalue / childnode.visits
|
||||||
|
|
||||||
|
if potential > highestProgressValue
|
||||||
|
highestProgressValue = potential
|
||||||
|
nodekey = childnode.nodekey
|
||||||
|
end
|
||||||
|
end
|
||||||
|
else
|
||||||
|
for (k, childnode) in node.children
|
||||||
|
potential = childnode.progressvalue + childnode.reward
|
||||||
|
|
||||||
|
if potential > highestProgressValue
|
||||||
|
highestProgressValue = potential
|
||||||
|
nodekey = childnode.nodekey
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return node.children[nodekey]
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node of a search tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `childNode::MCTSNode`
|
||||||
|
the highest value child node
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
- [x] implement the function
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function selectBestTrajectory(node::MCTSNode)::MCTSNode
|
||||||
|
while !isleaf(node)
|
||||||
|
node = selectBestNextState(node)
|
||||||
|
end
|
||||||
|
|
||||||
|
return node
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Backpropagate reward along the simulation chain
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
leaf node of a search tree
|
||||||
|
- `simTrajectoryReward::T`
|
||||||
|
total reward from trajectory simulation
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `No return`
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
||||||
|
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
|
||||||
|
while !isroot(node)
|
||||||
|
# Update the statistics of the current node based on the result of the playout
|
||||||
|
node.visits += 1
|
||||||
|
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||||
|
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||||
|
node = node.parent
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Determine whether a node is a leaf node of a search tree.
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
a search tree node
|
||||||
|
# Return
|
||||||
|
- `result::Bool`
|
||||||
|
true if it is a leaf node, false otherwise.
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia> using Revise
|
||||||
|
julia> using YiemAgent, DataStructures
|
||||||
|
julia> initialState = Dict{Symbol, Any}(
|
||||||
|
:customerinfo=> Dict{Symbol, Any}(),
|
||||||
|
:storeinfo=> Dict{Symbol, Any}(),
|
||||||
|
|
||||||
|
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||||
|
:question=> "How are you?",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
julia> statetype = typeof(initialState)
|
||||||
|
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
|
||||||
|
julia> YiemAgent.isleaf(root)
|
||||||
|
true
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
[] update docs
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
|
|
||||||
|
""" Determine wheter a given node is a root node
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node of a search tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `isrootnode::Bool`
|
||||||
|
true if the given node is root node, false otherwise
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
""" Select child node based on the highest statevalue
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node of a search tree
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `childNode::MCTSNode`
|
||||||
|
the highest value child node
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function selectChildNode(node::MCTSNode)::MCTSNode
|
||||||
|
highestProgressValue = 0
|
||||||
|
nodekey = nothing
|
||||||
|
|
||||||
|
# loop thought node children dictionary to find the highest progress value
|
||||||
|
for (k, childNode) in node.children
|
||||||
|
potential = childNode.progressvalue + childNode.reward
|
||||||
|
if childNode.reward > 0 #XXX for testing. remove when done.
|
||||||
|
println("")
|
||||||
|
end
|
||||||
|
if potential > highestProgressValue
|
||||||
|
highestProgressValue = potential
|
||||||
|
nodekey = childNode.nodekey
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return node.children[nodekey]
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Expand selected node
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `a::T1`
|
||||||
|
One of YiemAgent's agent
|
||||||
|
- `node::MCTSNode`
|
||||||
|
MCTS node
|
||||||
|
- `state::T2`
|
||||||
|
a state of a game. Can be a Dict or something else.
|
||||||
|
- `decisionMaker::Function`
|
||||||
|
a function that output Thought and Action
|
||||||
|
- `evaluator::Function`
|
||||||
|
a function that output trajectory progress score
|
||||||
|
|
||||||
|
# Return
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
[] update docstring
|
||||||
|
[] try loop should limit to 3 times. if not succeed, skip
|
||||||
|
[] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise.
|
||||||
|
[x] store feedback -> state -> agent.
|
||||||
|
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||||
|
reflector::Function, transition::Function; totalsample::Integer=3
|
||||||
|
) where {T1<:AbstractDict}
|
||||||
|
|
||||||
|
nthSample = 0
|
||||||
|
while true
|
||||||
|
nthSample += 1
|
||||||
|
if nthSample <= totalsample
|
||||||
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
|
println("---> expand() sample $nthSample")
|
||||||
|
pprintln(node.state[:thoughtHistory])
|
||||||
|
pprintln(thoughtDict)
|
||||||
|
newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict)
|
||||||
|
|
||||||
|
stateevaluation, progressvalue = evaluator(workDict, newstate)
|
||||||
|
|
||||||
|
if newstate[:reward] < 0
|
||||||
|
pprint(newstate[:thoughtHistory])
|
||||||
|
newstate[:evaluation] = stateevaluation
|
||||||
|
newstate[:lesson] = reflector(a, newstate)
|
||||||
|
|
||||||
|
# store new lesson for later use
|
||||||
|
lessonDict = copy(JSON3.read("lesson.json"))
|
||||||
|
latestLessonKey, latestLessonIndice =
|
||||||
|
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
|
||||||
|
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||||
|
newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||||
|
lessonDict[newLessonKey] = newstate
|
||||||
|
open("lesson.json", "w") do io
|
||||||
|
JSON3.pretty(io, lessonDict)
|
||||||
|
end
|
||||||
|
print("---> reflector()")
|
||||||
|
end
|
||||||
|
|
||||||
|
if newNodeKey ∉ keys(node.children)
|
||||||
|
node.children[newNodeKey] =
|
||||||
|
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
|
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||||
|
end
|
||||||
|
else
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Get a new state
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `a::T1`
|
||||||
|
one of YiemAgent's agent
|
||||||
|
- `state::T2`
|
||||||
|
current game state
|
||||||
|
- `thoughtDict::T3`
|
||||||
|
contain Thought, Action, Observation
|
||||||
|
- `isterminal::Function`
|
||||||
|
a function to determine terminal state
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||||
|
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
|
||||||
|
:storeinfo => Dict(),
|
||||||
|
:customerinfo => Dict()
|
||||||
|
)
|
||||||
|
julia> thoughtDict = Dict(
|
||||||
|
:question=> "I want to buy a bottle of wine.",
|
||||||
|
:thought_1=> "The customer wants to buy a bottle of wine.",
|
||||||
|
:action_1=> Dict{Symbol, Any}(
|
||||||
|
:name=>"Chatbox",
|
||||||
|
:input=>"What occasion are you buying the wine for?",
|
||||||
|
),
|
||||||
|
:observation_1 => ""
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] add other actions
|
||||||
|
- [WORKING] add embedding of newstate and store in newstate[:embedding]
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
|
||||||
|
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||||
|
|
||||||
|
# actionname = thoughtDict[:action][:name]
|
||||||
|
# actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
|
# # map action and input() to llm function
|
||||||
|
# response, select, reward, isterminal =
|
||||||
|
# if actionname == "chatbox"
|
||||||
|
# # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
|
||||||
|
# # so that other simulation start from this same node is not contaminated with actioninput
|
||||||
|
# virtualWineUserChatbox(workDict, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
|
||||||
|
# elseif actionname == "winestock"
|
||||||
|
# winestock(a, actioninput)
|
||||||
|
# elseif actionname == "recommendbox"
|
||||||
|
# virtualWineUserRecommendbox(workDict, actioninput)
|
||||||
|
# else
|
||||||
|
# error("undefined LLM function. Requesting $actionname")
|
||||||
|
# end
|
||||||
|
|
||||||
|
# newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal)
|
||||||
|
# if actionname == "chatbox"
|
||||||
|
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
|
||||||
|
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
|
||||||
|
# end
|
||||||
|
|
||||||
|
return (newNodeKey, newstate)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end # module mcts
|
||||||
116
src/type.jl
Normal file
116
src/type.jl
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
module type
|
||||||
|
|
||||||
|
export MCTSNode
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
""" a node for MCTS search tree
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `state::T`
|
||||||
|
a state of a game. Can be a Dict or something else.
|
||||||
|
- `visits::Integer `
|
||||||
|
number of time the game visits this state
|
||||||
|
- `stateValue::Float64`
|
||||||
|
state value
|
||||||
|
- `children::Dict{T, MCTSNode}`
|
||||||
|
children node
|
||||||
|
|
||||||
|
# Return
|
||||||
|
- `nothing`
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia> state = Dict(
|
||||||
|
:info=> Dict(), # keyword info
|
||||||
|
:thoughtHistory=> Dict(
|
||||||
|
:question=> _,
|
||||||
|
:thought_1=> _,
|
||||||
|
:action_1=> _,
|
||||||
|
:observation_1=> _,
|
||||||
|
:thought_2=> _,
|
||||||
|
...
|
||||||
|
)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
[] update docstring
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||||
|
nodekey::T2
|
||||||
|
state::T1
|
||||||
|
visits::Integer
|
||||||
|
progressvalue::Number # estimate value by LLM's reasoning
|
||||||
|
statevalue::Number # store discounted commulative reward (gather from its child node)
|
||||||
|
reward::Number # this node's own reward
|
||||||
|
isterminal::Bool
|
||||||
|
parent::Union{MCTSNode, Nothing}
|
||||||
|
children::Dict{String, MCTSNode}
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end # module type
|
||||||
139
src/util.jl
Normal file
139
src/util.jl
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
module util
|
||||||
|
|
||||||
|
export UCTselect
|
||||||
|
|
||||||
|
using ..type
|
||||||
|
|
||||||
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
""" Select a node based on UCT score
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
mcts node
|
||||||
|
- `w::T`
|
||||||
|
exploration weight. Value is usually between 1 to 2.
|
||||||
|
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%.
|
||||||
|
Value 2.0 makes MCTS aggressively search the tree.
|
||||||
|
# Return
|
||||||
|
- `selectedNode::MCTSNode`
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
||||||
|
maxUCT = -Inf
|
||||||
|
selectedNode = nothing
|
||||||
|
|
||||||
|
for (childState, childNode) in node.children
|
||||||
|
UCTvalue =
|
||||||
|
if childNode.visits != 0
|
||||||
|
weightedterm = w * sqrt(log(node.visits) / childNode.visits) # explore term
|
||||||
|
childNode.statevalue + weightedterm
|
||||||
|
else # node.visits == 0 makes sqrt() in explore term error
|
||||||
|
childNode.progressvalue # exploit term
|
||||||
|
end
|
||||||
|
|
||||||
|
if UCTvalue > maxUCT
|
||||||
|
maxUCT = UCTvalue
|
||||||
|
selectedNode = childNode
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return selectedNode
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end # module util
|
||||||
Reference in New Issue
Block a user