This commit is contained in:
narawat lamaiin
2024-06-01 08:17:48 +07:00
parent 74a4190440
commit 2378ddfa70
4 changed files with 458 additions and 94 deletions

View File

@@ -1,7 +1,9 @@
module mcts
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, mctstransition
expand, simulate, makeNewState
using GeneralUtils
using ..type
@@ -242,7 +244,7 @@ julia>
# Signature
"""
function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
function expand(config::T1, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; totalsample::Integer=3
) where {T1<:AbstractDict}
@@ -250,32 +252,8 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
while true
nthSample += 1
if nthSample <= totalsample
thoughtDict = decisionMaker(workDict, node.state)
println("---> expand() sample $nthSample")
pprintln(node.state[:thoughtHistory])
pprintln(thoughtDict)
newNodeKey, newstate = mctstransition(workDict, transition, node.state, thoughtDict)
stateevaluation, progressvalue = evaluator(workDict, newstate)
if newstate[:reward] < 0
pprint(newstate[:thoughtHistory])
newstate[:evaluation] = stateevaluation
newstate[:lesson] = reflector(workDict, newstate)
# store new lesson for later use
lessonDict = copy(JSON3.read("lesson.json"))
latestLessonKey, latestLessonIndice =
GeneralUtils.findHighestIndexKey(lessonDict, "lesson")
nextIndice = latestLessonKey == :NA ? 1 : latestLessonIndice + 1
newLessonKey = Symbol("lesson_$(nextIndice)")
lessonDict[newLessonKey] = newstate
open("lesson.json", "w") do io
JSON3.pretty(io, lessonDict)
end
print("---> reflector()")
end
newNodeKey, newstate, progressvalue = transition(config, node.state, decisionMaker,
evaluator, reflector)
if newNodeKey keys(node.children)
node.children[newNodeKey] =
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
@@ -311,8 +289,8 @@ julia>
# Signature
"""
function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
function simulate(config::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
reflector::Function, transition::Function; maxDepth::Integer=3, totalsample::Integer=3
)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}
simTrajectoryReward = 0.0
@@ -324,7 +302,8 @@ function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluato
terminalstate = node.state
break
else
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
expand(config, node, decisionMaker, evaluator, reflector, transition;
totalsample=totalsample)
node = selectChildNode(node)
end
end
@@ -333,71 +312,58 @@ function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluato
end
""" Get a new state
"""
# Arguments
- `a::T1`
one of YiemAgent's agents
- `state::T2`
current game state
- `thoughtDict::T3`
contains the Thought, Action, and Observation
- `isterminal::Function`
a function to determine terminal state
# Return
- `(newNodeKey, newstate)::Tuple{String, Dict{Symbol, <:Any}}`
# Example
```jldoctest
julia> state = Dict{Symbol, Dict{Symbol, Any}}(
:thoughtHistory => Dict(:question => "Hello, I want to buy a bottle of wine."),
:storeinfo => Dict(),
:customerinfo => Dict()
)
julia> thoughtDict = Dict(
:question=> "I want to buy a bottle of wine.",
:thought_1=> "The customer wants to buy a bottle of wine.",
:action_1=> Dict{Symbol, Any}(
:name=>"Chatbox",
:input=>"What occasion are you buying the wine for?",
),
:observation_1 => ""
)
julia>
```
# TODO
- [] add other actions
- [WORKING] add embedding of newstate and store in newstate[:embedding]
- [] update docstring
- [x] implement the function
# Signature
"""
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
error("--> mctstransition")
# actionname = thoughtDict[:action][:name]
# actioninput = thoughtDict[:action][:input]
function makeNewState(
    currentstate::T1,
    thoughtDict::T4,
    response::T2,
    select::Union{T3, Nothing},
    reward::T3,
    isterminal::Bool,
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString,
        T3<:Number, T4<:AbstractDict}
    # Build the successor of `currentstate`: append the newest thought/action taken from
    # `thoughtDict` together with `response` (stored as the observation) to a copy of the
    # thought history, then record `reward`, `select` and `isterminal` on that copy.
    # Returns `(newNodeKey, newstate)`; `currentstate` itself is left untouched.

    # Slot number for the appended entries. A highest key of `:NA` is treated as
    # "no thought entries yet", so numbering starts at 1.
    highestKey, highestIdx =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    slot = if highestKey == :NA
        1
    else
        highestIdx + 1
    end

    # Locate the entries to copy out of `thoughtDict`. An index of -1 is treated as
    # "keys carry no numeric suffix", i.e. plain `:thought` / `:action`.
    _, sourceIdx = GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
    suffix = sourceIdx == -1 ? "" : "_$sourceIdx"
    sourceThoughtKey = Symbol("thought", suffix)
    sourceActionKey = Symbol("action", suffix)

    # Work on a deep copy so the nested thought history of the parent state is not shared.
    newstate = deepcopy(currentstate)
    history = newstate[:thoughtHistory]
    history[Symbol("thought_$slot")] = thoughtDict[sourceThoughtKey]
    history[Symbol("action_$slot")] = thoughtDict[sourceActionKey]
    history[Symbol("observation_$(slot)")] = response

    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    # Each new state gets a freshly generated key.
    return (GeneralUtils.uuid4snakecase(), newstate)
end
@@ -460,13 +426,6 @@ end