update
This commit is contained in:
116
src/mcts.jl
116
src/mcts.jl
@@ -5,7 +5,7 @@
|
||||
|
||||
module mcts
|
||||
|
||||
export MCTSNode, runMCTS
|
||||
export MCTSNode, runMCTS, isleaf
|
||||
|
||||
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||
using GeneralUtils
|
||||
@@ -18,34 +18,30 @@ using ..type, ..llmfunction
|
||||
# Arguments
|
||||
- `state::T`
|
||||
a state of a game. Can be a Dict or something else.
|
||||
For example:
|
||||
state = Dict(
|
||||
:info=> Dict(), # keyword info
|
||||
:thoughtHistory=> Dict(
|
||||
:question=> _,
|
||||
:thought_1=> _,
|
||||
:action_1=> _,
|
||||
:observation_1=> _,
|
||||
:thought_2=> _,
|
||||
...
|
||||
)
|
||||
)
|
||||
- `visits::Integer `
|
||||
number of time the game visits this state
|
||||
- `stateValue::Float64`
|
||||
state value
|
||||
- `children::Dict{T, MCTSNode}`
|
||||
children node
|
||||
|
||||
# Return
|
||||
|
||||
- `nothing`
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
julia> state = Dict(
|
||||
:info=> Dict(), # keyword info
|
||||
:thoughtHistory=> Dict(
|
||||
:question=> _,
|
||||
:thought_1=> _,
|
||||
:action_1=> _,
|
||||
:observation_1=> _,
|
||||
:thought_2=> _,
|
||||
...
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
[] update docstring
|
||||
[x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
struct MCTSNode{T<:AbstractDict}
|
||||
@@ -131,14 +127,15 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, state
|
||||
)
|
||||
:Observation_1 => ""
|
||||
"""
|
||||
|
||||
@show state
|
||||
@show thoughtDict
|
||||
newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
|
||||
if newstate ∉ keys(node.children)
|
||||
node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{T, MCTSNode}())
|
||||
statetype = typeof(state)
|
||||
node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||
end
|
||||
end
|
||||
error("--> expand")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
@@ -208,7 +205,7 @@ end
|
||||
one of YiemAgent's agent
|
||||
- `state::T2`
|
||||
current game state
|
||||
- `thoughtDict::T2`
|
||||
- `thoughtDict::T3`
|
||||
contain Thought, Action, Observation
|
||||
|
||||
# Return
|
||||
@@ -217,26 +214,32 @@ end
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> thoughtDict = Dict(
|
||||
:Question=> "I want to buy a bottle of wine."
|
||||
:Thought_1=> "The customer wants to buy a bottle of wine. This is a good start!",
|
||||
:Action_1=> Dict{Symbol, Any}(
|
||||
:name=>"Chatbox",
|
||||
:input=>"What occasion are you buying the wine for?"
|
||||
),
|
||||
:Observation_1 => ""
|
||||
)
|
||||
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||
:thoughtHistory => Dict(:Question => "Hello, I want to buy a bottle of wine."),
|
||||
:storeinfo => Dict(),
|
||||
:customerinfo => Dict()
|
||||
)
|
||||
julia> thoughtDict = Dict(
|
||||
:Question=> "I want to buy a bottle of wine.",
|
||||
:Thought_1=> "The customer wants to buy a bottle of wine.",
|
||||
:Action_1=> Dict{Symbol, Any}(
|
||||
:name=>"Chatbox",
|
||||
:input=>"What occasion are you buying the wine for?",
|
||||
),
|
||||
:Observation_1 => ""
|
||||
)
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [x] update docstring
|
||||
- [TESTING] implement the function
|
||||
- [] update docstring
|
||||
- [PENDING] add other actions
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {T1<:agent, T2<:AbstractDict}
|
||||
latestThoughtKey, latestindice = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
|
||||
latestActionKey = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
|
||||
function MCTStransition(a::T1, state::T2,
|
||||
thoughtDict::T3)::AbstractDict where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
|
||||
latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
|
||||
_action = thoughtDict[latestActionKey]
|
||||
actionname = _action[:name]
|
||||
actioninput = _action[:input]
|
||||
@@ -244,7 +247,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {
|
||||
# map action and input() to llm function
|
||||
response =
|
||||
if actionname == "chatbox"
|
||||
virtualWineCustomerChatbox(a, actioninput) # user virtu
|
||||
virtualWineCustomerChatbox(a, actioninput) # virtual customer
|
||||
elseif actionname == "winestock"
|
||||
|
||||
elseif actionname == "finish"
|
||||
@@ -257,28 +260,38 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T2)::AbstractDict where {
|
||||
newstate = deepcopy(state)
|
||||
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey]
|
||||
latestObservationKey = Symbol("Observation_$(latestindice)")
|
||||
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
|
||||
newstate[:thoughtHistory][latestObservationKey] = response
|
||||
|
||||
|
||||
error("--> transition")
|
||||
return newstate
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
""" Determine whether a node is a leaf node of a search tree.
|
||||
|
||||
# Arguments
|
||||
|
||||
- `node::MCTSNode`
|
||||
a search tree node
|
||||
# Return
|
||||
|
||||
- `result::Bool`
|
||||
true if it is a leaf node, false otherwise.
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
julia> using Revise
|
||||
julia> using YiemAgent, DataStructures
|
||||
julia> initialState = Dict{Symbol, Any}(
|
||||
:customerinfo=> Dict{Symbol, Any}(),
|
||||
:storeinfo=> Dict{Symbol, Any}(),
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
:thoughtHistory=> OrderedDict{Symbol, Any}(
|
||||
:Question=> "How are you?",
|
||||
)
|
||||
)
|
||||
julia> statetype = typeof(initialState)
|
||||
julia> root = YiemAgent.MCTSNode(initialState, 0, 0.0, Dict{statetype, YiemAgent.MCTSNode}())
|
||||
julia> YiemAgent.isleaf(root)
|
||||
true
|
||||
```
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -366,13 +379,14 @@ function runMCTS(
|
||||
end
|
||||
|
||||
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
|
||||
error("---> runMCTS")
|
||||
|
||||
leaf_node = node.children[node.state] # mark leaf node
|
||||
reward = simulate(leaf_node.state, maxDepth)
|
||||
backpropagate(leaf_node, reward)
|
||||
end
|
||||
|
||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||
error("---> runMCTS")
|
||||
return best_child_state
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user