This commit is contained in:
narawat lamaiin
2024-05-07 13:25:18 +07:00
parent 43e7ba3991
commit b93264ae58
4 changed files with 110 additions and 141 deletions

View File

@@ -333,7 +333,6 @@ end
# ``` # ```
# # TODO # # TODO
# [PENDING] add Reflect()
# # Signature # # Signature
# """ # """
@@ -407,8 +406,9 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?")
``` ```
# TODO # TODO
- [] update docstring - [] update docstring
- [PENDING] MCTS() for planning - [WORKING] MCTS() for planning
- [] add recap to initialState for earlier completed question
# Signature # Signature
""" """
@@ -441,11 +441,12 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
:reward=> 0, :reward=> 0,
:isterminal=> false, :isterminal=> false,
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
# :recap=>,
:question=> userinput[:text], :question=> userinput[:text],
) )
) )
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
isterminal, 2, 3, 3, 1.0) 2, 3, 4, 1.0)
error("---> bestplan") error("---> bestplan")
# actor loop(bestplan) # actor loop(bestplan)

View File

@@ -69,7 +69,8 @@ julia>
# Signature # Signature
""" """
function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} function virtualWineCustomerReccommendbox(a::T1, input
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent}
input = "I reccomment Zeno crown vista" input = "I reccomment Zeno crown vista"
@@ -92,7 +93,7 @@ function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:a
mqttBroker= a.config[:mqttServerInfo][:broker], mqttBroker= a.config[:mqttServerInfo][:broker],
mqttBrokerPort= a.config[:mqttServerInfo][:port], mqttBrokerPort= a.config[:mqttServerInfo][:port],
msgId = "dummyid" #CHANGE remove after testing finished msgId = "dummyid" #CHANGE remove after testing finished
) )
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
@@ -102,9 +103,9 @@ function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:a
) )
@show outgoingMsg @show outgoingMsg
result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120)
response = result[:response][:text] response = result[:response]
return response return (response[:text], response[:select], response[:reward], response[:isterminal])
end end
@@ -131,7 +132,8 @@ julia>
# Signature # Signature
""" """
function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} function virtualWineCustomerChatbox(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
# put in model format # put in model format
virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1]
@@ -192,7 +194,9 @@ julia> result = winestock(agent, input)
# Signature # Signature
""" """
function winestock(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} function winestock(a::T1, input::T2
)::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString}
winesStr = winesStr =
""" """
1: El Enemigo Cabernet Franc 2019 1: El Enemigo Cabernet Franc 2019
@@ -205,7 +209,7 @@ function winestock(a::T1, input::T2) where {T1<:agent, T2<:AbstractString}
$winesStr $winesStr
} }
""" """
return result return result, nothing, 0, false
end end
@@ -218,17 +222,14 @@ end
text to be sent to virtual wine customer text to be sent to virtual wine customer
# Return # Return
- `response::String` - `correctjson::String`
response of virtual wine customer corrected json string
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO
- [] update docstring
- [x] implement the function
# Signature # Signature
""" """
function jsoncorrection(a::T1, input::T2, function jsoncorrection(a::T1, input::T2,
@@ -306,77 +307,6 @@ function jsoncorrection(a::T1, input::T2,
return correctjson return correctjson
end end
# function jsoncorrection(a::T1, input::T2,
# correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString}
# attemptround = 0
# incorrectjson = deepcopy(input)
# correctjson = nothing
# while true
# attemptround += 1
# if attemptround <= 5
# try
# JSON3.read(incorrectjson)
# correctjson = incorrectjson
# break
# catch
# @warn "Attempting correct JSON string. $attemptround"
# incorrectjson = deepcopy(input)
# _prompt =
# """
# Your goal is to correct a given incorrect JSON format while retaining original content.
# $correctJsonExample
# Incorrect JSON:
# $incorrectjson
# Corrention:
# """
# # apply LLM specific instruct format
# externalService = a.config[:externalservice][:text2textinstruct]
# llminfo = externalService[:llminfo]
# prompt =
# if llminfo[:name] == "llama3instruct"
# formatLLMtext_llama3instruct("system", _prompt)
# else
# error("llm model name is not defied yet $(@__LINE__)")
# end
# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg
# msgMeta = GeneralUtils.generate_msgMeta(
# externalService[:mqtttopic],
# senderName= "jsoncorrection",
# senderId= a.id,
# receiverName= "text2textinstruct",
# mqttBroker= a.config[:mqttServerInfo][:broker],
# mqttBrokerPort= a.config[:mqttServerInfo][:port],
# )
# outgoingMsg = Dict(
# :msgMeta=> msgMeta,
# :payload=> Dict(
# :text=> prompt,
# :kwargs=> Dict(
# :max_tokens=> 512,
# :stop=> ["<|eot_id|>"],
# )
# )
# )
# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
# incorrectjson = result[:response][:text]
# end
# else
# error("Can't fix JSON string")
# break
# end
# end
# return correctjson
# end

View File

@@ -64,22 +64,21 @@ end
# Arguments # Arguments
- `node::MCTSNode` - `node::MCTSNode`
mcts node mcts node
- `w::Float64` - `w::T`
exploration weight exploration weight. Value is usually between 1 and 2.
A value of 1.0 makes MCTS balance exploration and exploitation roughly 50%-50%.
A value of 2.0 makes MCTS search the tree aggressively.
# Return # Return
- `selectedNode::MCTSNode`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO
[] update docstring
[x] check childNode.total_reward w/ LATS paper. Which value total_reward representing
# Signature # Signature
""" """
function UCTselect(node::MCTSNode, w::Float64) function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
max_uct = -Inf max_uct = -Inf
selectedNode = nothing selectedNode = nothing
@@ -130,7 +129,7 @@ julia>
# Signature # Signature
""" """
function expand(a::T1, node::MCTSNode, decisionMaker::Function, function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent} progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
nthSample = 0 nthSample = 0
while nthSample < n while nthSample < n
@@ -138,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
thoughtDict = decisionMaker(a, node.state) thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, reward, isterminalstate = newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict, isterminal) MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator # add progressValueEstimator
stateevaluation, statevalue = progressValueEstimator(a, newstate) stateevaluation, statevalue = progressValueEstimator(a, newstate)
@@ -148,69 +147,78 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
reward, isterminalstate, node, Dict{String, MCTSNode}()) reward, isterminalstate, node, Dict{String, MCTSNode}())
end end
nthSample += 1 nthSample += 1
catch catch e
# skip this child node if error occurs io = IOBuffer()
println("retry node expand") showerror(io, e)
errorMsg = String(take!(io))
st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace()))
println("")
@warn "Error occurred: $errorMsg\n$st"
println("")
end end
end end
end end
"""
""" Simulate interactions between agent and environment
# Arguments # Arguments
- `a::T`
one of YiemAgent's agents
- `node::MCTSNode` - `node::MCTSNode`
node that will be a simulation starting point. node that will be a simulation starting point.
- `decisionMaker::Function`
function that receives a state and returns a Thought and an Action
# Return # Return
- `simTrajectoryReward::Number`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO
- [] update docstring
- [x] implement the function
- [] check for the terminal state (node.reward != 0), break if it is terminal state
# Signature # Signature
""" """
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
isterminal::Function, maxDepth::Int; n=3)::Number maxDepth::Int; n=3)::Number where {T<:agent}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
for depth in 1:maxDepth for depth in 1:maxDepth
simTrajectoryReward += node.reward simTrajectoryReward += node.reward
if node.isterminalrd if node.isterminal
break break
else else
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) expand(a, node, decisionMaker, progressValueEstimator, n=n)
node = selectChildNode(node) node = selectChildNode(node)
end end
end end
#BUG the newly expanded state has a reward, but it is not included because that state lies one step beyond maxDepth
return simTrajectoryReward return simTrajectoryReward
end end
""" """ Backpropagate reward along the simulation chain
# Arguments # Arguments
- `node::MCTSNode`
node of a search tree
- `simTrajectoryReward::T`
total reward from all nodes in the simulation trajectory
# Return # Return
- `No return`
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO
- [] update docstring
- [WORKING] implement the function
# Signature # Signature
""" """
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9) function backpropagate(node::MCTSNode, simTrajectoryReward::T;
discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
while !isroot(node) while !isroot(node)
# Update the statistics of the current node based on the result of the playout # Update the statistics of the current node based on the result of the playout
node.visits += 1 node.visits += 1
@@ -260,8 +268,8 @@ julia> thoughtDict = Dict(
# Signature # Signature
""" """
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function function MCTStransition(a::T1, state::T2, thoughtDict::T3
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
actionname = thoughtDict[:action][:name] actionname = thoughtDict[:action][:name]
actioninput = thoughtDict[:action][:input] actioninput = thoughtDict[:action][:input]
@@ -383,10 +391,6 @@ end
julia> julia>
``` ```
# TODO
[] update docs
[TESTING] implement the function
# Signature # Signature
""" """
isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
@@ -437,7 +441,6 @@ function runMCTS(
decisionMaker::Function, decisionMaker::Function,
progressValueEstimator::Function, progressValueEstimator::Function,
reflector::Function, reflector::Function,
isterminal::Function,
n::Integer, n::Integer,
maxDepth::Integer, maxDepth::Integer,
maxIterations::Integer, maxIterations::Integer,
@@ -455,10 +458,10 @@ function runMCTS(
# do nothing then go directly to backpropagation # do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward) backpropagate(leafNode, node.reward)
else else
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) expand(a, node, decisionMaker, progressValueEstimator, n=n)
leafNode = UCTselect(node, w) leafNode = UCTselect(node, w)
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
isterminal, maxDepth, n=n) maxDepth, n=n)
backpropagate(leafNode, simTrajectoryReward) backpropagate(leafNode, simTrajectoryReward)
end end
end end

View File

@@ -42,6 +42,9 @@ outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "It will be Thai dishes.", :text=> "It will be Thai dishes.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -52,7 +55,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "50 bucks.", :text=> "I would spend up to 50 bucks.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -64,6 +70,9 @@ outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "I like full-bodied Red wine with low tannin.", :text=> "I like full-bodied Red wine with low tannin.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -74,28 +83,22 @@ outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "What do you have?", :text=> "What do you have?",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict(
:msgMeta=> msgMeta,
:payload=> Dict(
:text=> "OK, I'll take it.",
)
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "Dry please.", :text=> "Dry please.",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -107,6 +110,9 @@ outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "You did not gave me any choice.", :text=> "You did not gave me any choice.",
:select=> nothing,
:reward=> -1,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -117,7 +123,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
outgoingMsg = Dict( outgoingMsg = Dict(
:msgMeta=> msgMeta, :msgMeta=> msgMeta,
:payload=> Dict( :payload=> Dict(
:text=> "Yes.", :text=> "Are there any other options?",
:select=> nothing,
:reward=> 0,
:isterminal=> false,
) )
) )
result = GeneralUtils.sendMqttMsg(outgoingMsg) result = GeneralUtils.sendMqttMsg(outgoingMsg)
@@ -125,5 +134,31 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg)
# Send a message whose payload carries the text "Yep." with nothing selected,
# a reward of 0 and `:isterminal` set to false.
outgoingMsg = Dict(
    :msgMeta => msgMeta,
    :payload => Dict(
        :text => "Yep.",
        :select => nothing,
        :reward => 0,
        :isterminal => false,
    ),
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)

# Send a message whose payload carries the text "OK, I'll take it." with
# `:select` 1, a reward of 1 and `:isterminal` set to true.
outgoingMsg = Dict(
    :msgMeta => msgMeta,
    :payload => Dict(
        :text => "OK, I'll take it.",
        :select => 1,
        :reward => 1,
        :isterminal => true,
    ),
)
result = GeneralUtils.sendMqttMsg(outgoingMsg)