update
This commit is contained in:
@@ -47,7 +47,7 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function runMCTS(
|
||||
workDict::T1,
|
||||
config::T1,
|
||||
initialState,
|
||||
decisionMaker::Function,
|
||||
evaluator::Function,
|
||||
@@ -74,11 +74,11 @@ function runMCTS(
|
||||
# do nothing then go directly to backpropagation
|
||||
backpropagate(leafNode, node.reward)
|
||||
else
|
||||
expand(workDict, node, decisionMaker, evaluator, reflector, transition;
|
||||
expand(config, node, decisionMaker, evaluator, reflector, transition;
|
||||
totalsample=totalsample)
|
||||
leafNode = selectChildNode(node)
|
||||
simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator,
|
||||
reflector; maxDepth=maxDepth, totalsample=totalsample)
|
||||
simTrajectoryReward, terminalstate = simulate(config, leafNode, decisionMaker, evaluator,
|
||||
reflector, transition; maxDepth=maxDepth, totalsample=totalsample)
|
||||
if terminalstate !== nothing #XXX not sure why I need this
|
||||
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
|
||||
end
|
||||
|
||||
137
src/mcts.jl
137
src/mcts.jl
@@ -1,7 +1,9 @@
|
||||
module mcts
|
||||
|
||||
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, mctstransition
|
||||
expand, simulate, makeNewState
|
||||
|
||||
using GeneralUtils
|
||||
|
||||
using ..type
|
||||
|
||||
@@ -242,7 +244,7 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||
function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||
reflector::Function, transition::Function; totalsample::Integer=3
|
||||
) where {T1<:AbstractDict}
|
||||
|
||||
@@ -250,32 +252,8 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= totalsample
|
||||
thoughtDict = decisionMaker(workDict, node.state)
|
||||
println("---> expand() sample $nthSample")
|
||||
pprintln(node.state[:thoughtHistory])
|
||||
pprintln(thoughtDict)
|
||||
newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict)
|
||||
|
||||
stateevaluation, progressvalue = evaluator(workDict, newstate)
|
||||
|
||||
if newstate[:reward] < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
newstate[:evaluation] = stateevaluation
|
||||
newstate[:lesson] = reflector(workDict, newstate)
|
||||
|
||||
# store new lesson for later use
|
||||
lessonDict = copy(JSON3.read("lesson.json"))
|
||||
latestLessonKey, latestLessonIndice =
|
||||
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
|
||||
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
|
||||
newLessonKey = Symbol("lesson_$(nextIndice)")
|
||||
lessonDict[newLessonKey] = newstate
|
||||
open("lesson.json", "w") do io
|
||||
JSON3.pretty(io, lessonDict)
|
||||
end
|
||||
print("---> reflector()")
|
||||
end
|
||||
|
||||
newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker,
|
||||
evaluator, reflector)
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] =
|
||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||
@@ -311,8 +289,8 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||
reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
|
||||
function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||
reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3
|
||||
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
@@ -324,7 +302,8 @@ function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluato
|
||||
terminalstate = node.state
|
||||
break
|
||||
else
|
||||
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
|
||||
expand(config, node, decisionMaker, evaluator, reflector, transition;
|
||||
totalsample=totalsample)
|
||||
node = selectChildNode(node)
|
||||
end
|
||||
end
|
||||
@@ -333,71 +312,58 @@ function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluato
|
||||
end
|
||||
|
||||
|
||||
|
||||
""" Get a new state
|
||||
"""
|
||||
|
||||
# Arguments
|
||||
- `a::T1`
|
||||
one of YiemAgent's agent
|
||||
- `state::T2`
|
||||
current game state
|
||||
- `thoughtDict::T3`
|
||||
contain Thought, Action, Observation
|
||||
- `isterminal::Function`
|
||||
a function to determine terminal state
|
||||
|
||||
|
||||
# Return
|
||||
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
|
||||
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
|
||||
:storeinfo => Dict(),
|
||||
:customerinfo => Dict()
|
||||
)
|
||||
julia> thoughtDict = Dict(
|
||||
:question=> "I want to buy a bottle of wine.",
|
||||
:thought_1=> "The customer wants to buy a bottle of wine.",
|
||||
:action_1=> Dict{Symbol, Any}(
|
||||
:name=>"Chatbox",
|
||||
:input=>"What occasion are you buying the wine for?",
|
||||
),
|
||||
:observation_1 => ""
|
||||
)
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] add other actions
|
||||
- [WORKING] add embedding of newstate and store in newstate[:embedding]
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||
error("--> mctstransition")
|
||||
# actionname = thoughtDict[:action][:name]
|
||||
# actioninput = thoughtDict[:action][:input]
|
||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
reward::T3, isterminal::Bool
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
|
||||
# # map action and input() to llm function
|
||||
# response, select, reward, isterminal =
|
||||
# if actionname == "chatbox"
|
||||
# # deepcopy(state[:virtualCustomerChatHistory]) because I want to keep it clean
|
||||
# # so that other simulation start from this same node is not contaminated with actioninput
|
||||
# virtualWineUserChatbox(workDict, actioninput, deepcopy(state[:virtualCustomerChatHistory])) # virtual customer
|
||||
# elseif actionname == "winestock"
|
||||
# winestock(a, actioninput)
|
||||
# elseif actionname == "recommendbox"
|
||||
# virtualWineUserRecommendbox(workDict, actioninput)
|
||||
# else
|
||||
# error("undefined LLM function. Requesting $actionname")
|
||||
# end
|
||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
|
||||
# newNodeKey, newstate = makeNewState(state, thoughtDict, response, select, reward, isterminal)
|
||||
# if actionname == "chatbox"
|
||||
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"assistant", :text=> actioninput) )
|
||||
# push!(newstate[:virtualCustomerChatHistory], Dict(:name=>"user", :text=> response))
|
||||
# end
|
||||
_, thoughtDict_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
|
||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
if thoughtDict_latestThoughtIndice == -1
|
||||
(:thought, :action)
|
||||
else
|
||||
(
|
||||
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||
)
|
||||
end
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
newstate = deepcopy(currentstate)
|
||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
thoughtDict[thoughtDict_latestThoughtKey]
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
end
|
||||
@@ -460,13 +426,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user