update
This commit is contained in:
102
src/mcts.jl
102
src/mcts.jl
@@ -47,10 +47,11 @@ julia> state = Dict(
|
||||
|
||||
# Signature
|
||||
"""
|
||||
mutable struct MCTSNode{T<:AbstractDict}
|
||||
nodekey::String
|
||||
state::T
|
||||
mutable struct MCTSNode{T1<:AbstractDict, T2<:AbstractString}
|
||||
nodekey::T2
|
||||
state::T1
|
||||
visits::Integer
|
||||
stateevaluation::T2
|
||||
statevalue::Number
|
||||
reward::Number
|
||||
isterminal::Bool
|
||||
@@ -74,7 +75,7 @@ julia>
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
[TESTING] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
||||
[x] check childNode.total_reward w/ LATS paper. Which value total_reward representing
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -83,12 +84,18 @@ function UCTselect(node::MCTSNode, w::Float64)
|
||||
selectedNode = nothing
|
||||
|
||||
for (childState, childNode) in node.children
|
||||
uctValue = childNode.statevalue +
|
||||
w * sqrt(log(node.visits) / childNode.visits)
|
||||
if uctValue > max_uct
|
||||
max_uct = uctValue
|
||||
selectedNode = childNode
|
||||
end
|
||||
weightedterm =
|
||||
if node.visits == 0 || childNode.visits == 0
|
||||
0
|
||||
else
|
||||
w * sqrt(log(node.visits) / childNode.visits)
|
||||
end
|
||||
uctValue = childNode.statevalue + weightedterm
|
||||
|
||||
if uctValue > max_uct
|
||||
max_uct = uctValue
|
||||
selectedNode = childNode
|
||||
end
|
||||
end
|
||||
|
||||
return selectedNode
|
||||
@@ -132,11 +139,10 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
isterminal)
|
||||
|
||||
# add progressValueEstimator
|
||||
progressRationale, statevalue = progressValueEstimator(a, newstate)
|
||||
statevalue += reward
|
||||
stateevaluation, statevalue = progressValueEstimator(a, newstate)
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, stateevaluation, statevalue,
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
end
|
||||
@@ -163,18 +169,18 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
isterminal::Function, max_depth::Int; n=3)::Number
|
||||
isterminal::Function, maxDepth::Int; n=3)::Number
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
|
||||
for _ in 1:max_depth
|
||||
for depth in 1:maxDepth
|
||||
if node.isterminal
|
||||
break
|
||||
else
|
||||
simTrajectoryReward += node.reward
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
node = selectChildNode(node)
|
||||
simTrajectoryReward += node.reward
|
||||
end
|
||||
|
||||
return simTrajectoryReward
|
||||
@@ -216,26 +222,14 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||
|
||||
# Backpropagate the result to the parent node recursively
|
||||
if !isroot(node)
|
||||
simTrajectoryReward *= discountRewardCoeff
|
||||
backpropagate(node.parent, simTrajectoryReward)
|
||||
while !isroot(node)
|
||||
# Update the statistics of the current node based on the result of the playout
|
||||
node.visits += 1
|
||||
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||
node = node.parent
|
||||
end
|
||||
end
|
||||
# function backpropagate(node::MCTSNode, reward::Float64)
|
||||
# node.visits += 1
|
||||
|
||||
# # [] there is no total_reward in the paper, buy they use stateValue
|
||||
# node.total_reward += reward
|
||||
# if !isempty(node.children)
|
||||
# best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||
# backpropagate(node.children[best_child], -reward)
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
""" Get a new state
|
||||
@@ -256,18 +250,18 @@ end
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||
:thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."),
|
||||
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
|
||||
:storeinfo => Dict(),
|
||||
:customerinfo => Dict()
|
||||
)
|
||||
julia> thoughtDict = Dict(
|
||||
:Question=> "I want to buy a bottle of wine.",
|
||||
:Thought_1=> "The customer wants to buy a bottle of wine.",
|
||||
:Action_1=> Dict{Symbol, Any}(
|
||||
:question=> "I want to buy a bottle of wine.",
|
||||
:thought_1=> "The customer wants to buy a bottle of wine.",
|
||||
:action_1=> Dict{Symbol, Any}(
|
||||
:name=>"Chatbox",
|
||||
:input=>"What occasion are you buying the wine for?",
|
||||
),
|
||||
:Observation_1 => ""
|
||||
:observation_1 => ""
|
||||
)
|
||||
```
|
||||
|
||||
@@ -280,8 +274,8 @@ julia> thoughtDict = Dict(
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
|
||||
actionname = thoughtDict[:Action][:name]
|
||||
actioninput = thoughtDict[:Action][:input]
|
||||
actionname = thoughtDict[:action][:name]
|
||||
actioninput = thoughtDict[:action][:input]
|
||||
|
||||
# map action and input() to llm function
|
||||
response =
|
||||
@@ -289,23 +283,23 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||
virtualWineCustomerChatbox(a, actioninput) # virtual customer
|
||||
elseif actionname == "winestock"
|
||||
winestock(a, actioninput)
|
||||
elseif actionname == "reccommendbox"
|
||||
elseif actionname == "recommendbox"
|
||||
virtualWineCustomerReccommendbox(a, actioninput)
|
||||
else
|
||||
error("undefined LLM function. Requesting $actionname")
|
||||
end
|
||||
|
||||
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
|
||||
"Thought")
|
||||
"thought")
|
||||
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
||||
latestThoughtKey = Symbol("Thought_$nextIndice")
|
||||
latestActionKey = Symbol("Action_$nextIndice")
|
||||
latestThoughtKey = Symbol("thought_$nextIndice")
|
||||
latestActionKey = Symbol("action_$nextIndice")
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
newstate = deepcopy(state)
|
||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:Thought]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:Action]
|
||||
latestObservationKey = Symbol("Observation_$(nextIndice)")
|
||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:thought]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
|
||||
latestObservationKey = Symbol("observation_$(nextIndice)")
|
||||
newstate[:thoughtHistory][latestObservationKey] = response
|
||||
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
@@ -332,7 +326,7 @@ julia> initialState = Dict{Symbol, Any}(
|
||||
:storeinfo=> Dict{Symbol, Any}(),
|
||||
|
||||
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||
:Question=> "How are you?",
|
||||
:question=> "How are you?",
|
||||
)
|
||||
)
|
||||
julia> statetype = typeof(initialState)
|
||||
@@ -341,6 +335,9 @@ julia> YiemAgent.isleaf(root)
|
||||
true
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docs
|
||||
|
||||
# Signature
|
||||
"""
|
||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
@@ -451,9 +448,9 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
root = MCTSNode("root", initialState, 0, "N/A", 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
|
||||
for _ in 1:maxIterations
|
||||
for nth in 1:maxIterations
|
||||
node = root
|
||||
while !isleaf(node)
|
||||
node = UCTselect(node, w)
|
||||
@@ -462,6 +459,7 @@ function runMCTS(
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
leaf_node = selectChildNode(node)
|
||||
# BUG i didn't assign parent node for this leaf node yet
|
||||
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
backpropagate(leaf_node, simTrajectoryReward)
|
||||
|
||||
Reference in New Issue
Block a user