diff --git a/src/interface.jl b/src/interface.jl index ed7579d..82a1eae 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -333,7 +333,6 @@ end # ``` # # TODO -# [PENDING] add Reflect() # # Signature # """ @@ -407,8 +406,9 @@ julia> response = ChatAgent.conversation(newAgent, "Hi! how are you?") ``` # TODO -- [] update docstring -- [PENDING] MCTS() for planning + - [] update docstring + - [WORKING] MCTS() for planning + - [] add recap to initialState for earlier completed question # Signature """ @@ -441,11 +441,12 @@ function conversation(a::T, userinput::Dict) where {T<:agent} :reward=> 0, :isterminal=> false, :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... + # :recap=>, :question=> userinput[:text], ) ) bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, - isterminal, 2, 3, 3, 1.0) + 2, 3, 4, 1.0) error("---> bestplan") # actor loop(bestplan) diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 349e08a..68dd120 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -69,7 +69,8 @@ julia> # Signature """ -function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} +function virtualWineCustomerReccommendbox(a::T1, input + )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent} input = "I reccomment Zeno crown vista" @@ -92,7 +93,7 @@ function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:a mqttBroker= a.config[:mqttServerInfo][:broker], mqttBrokerPort= a.config[:mqttServerInfo][:port], msgId = "dummyid" #CHANGE remove after testing finished - ) + ) outgoingMsg = Dict( :msgMeta=> msgMeta, @@ -102,9 +103,9 @@ function virtualWineCustomerReccommendbox(a::T1, input::T2)::String where {T1<:a ) @show outgoingMsg result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg; timeout=120) - response = result[:response][:text] + response = result[:response] - return response + return (response[:text], response[:select], response[:reward], response[:isterminal]) end @@ -131,7 +132,8 @@ julia> # Signature """ -function virtualWineCustomerChatbox(a::T1, input::T2)::String where {T1<:agent, T2<:AbstractString} +function virtualWineCustomerChatbox(a::T1, input::T2 + )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} # put in model format virtualWineCustomer = a.config[:externalservice][:virtualWineCustomer_1] @@ -192,7 +194,9 @@ julia> result = winestock(agent, input) # Signature """ -function winestock(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} +function winestock(a::T1, input::T2 + )::Union{Tuple{String, Number, Number, Bool}, Tuple{String, Nothing, Number, Bool}} where {T1<:agent, T2<:AbstractString} + winesStr = """ 1: El Enemigo Cabernet Franc 2019 @@ -205,7 +209,7 @@ function winestock(a::T1, input::T2) where {T1<:agent, T2<:AbstractString} $winesStr } """ - return result + return result, nothing, 0, false end @@ -218,17 +222,14 @@ end text to be send to virtual wine customer # Return - - `response::String` - response of virtual wine customer + - `correctjson::String` + corrected json string + # Example ```jldoctest julia> ``` -# TODO - - [] update docstring - - [x] implement the function - # Signature """ function jsoncorrection(a::T1, input::T2, @@ -306,77 +307,6 @@ function jsoncorrection(a::T1, input::T2, return correctjson end -# function jsoncorrection(a::T1, input::T2, -# correctJsonExample::T3) where {T1<:agent, T2<:AbstractString, T3<:AbstractString} - -# attemptround = 0 -# incorrectjson = deepcopy(input) -# correctjson = nothing -# while true -# attemptround += 1 -# if attemptround <= 5 -# try -# JSON3.read(incorrectjson) -# correctjson = incorrectjson -# break -# catch -# @warn "Attempting correct JSON string. $attemptround" -# incorrectjson = deepcopy(input) -# _prompt = -# """ -# Your goal is to correct a given incorrect JSON format while retaining original content. - -# $correctJsonExample - -# Incorrect JSON: -# $incorrectjson -# Corrention: -# """ - -# # apply LLM specific instruct format -# externalService = a.config[:externalservice][:text2textinstruct] -# llminfo = externalService[:llminfo] -# prompt = -# if llminfo[:name] == "llama3instruct" -# formatLLMtext_llama3instruct("system", _prompt) -# else -# error("llm model name is not defied yet $(@__LINE__)") -# end - -# # send formatted input to user using GeneralUtils.sendReceiveMqttMsg -# msgMeta = GeneralUtils.generate_msgMeta( -# externalService[:mqtttopic], -# senderName= "jsoncorrection", -# senderId= a.id, -# receiverName= "text2textinstruct", -# mqttBroker= a.config[:mqttServerInfo][:broker], -# mqttBrokerPort= a.config[:mqttServerInfo][:port], -# ) - -# outgoingMsg = Dict( -# :msgMeta=> msgMeta, -# :payload=> Dict( -# :text=> prompt, -# :kwargs=> Dict( -# :max_tokens=> 512, -# :stop=> ["<|eot_id|>"], -# ) -# ) -# ) -# result = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) -# incorrectjson = result[:response][:text] -# end -# else -# error("Can't fix JSON string") -# break -# end -# end - -# return correctjson -# end - - - diff --git a/src/mcts.jl b/src/mcts.jl index 1d3d618..a0b707c 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -64,22 +64,21 @@ end # Arguments - `node::MCTSNode` mcts node - - `w::Float64` - exploration weight + - `w::T` + exploration weight. Value is usually between 1 to 2. + Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%. + Value 2.0 makes MCTS aggressively search the tree. # Return + - `selectedNode::MCTSNode` # Example ```jldoctest julia> ``` -# TODO - [] update docstring - [x] check childNode.total_reward w/ LATS paper. Which value total_reward representing - # Signature """ -function UCTselect(node::MCTSNode, w::Float64) +function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat} max_uct = -Inf selectedNode = nothing @@ -130,7 +129,7 @@ julia> # Signature """ function expand(a::T1, node::MCTSNode, decisionMaker::Function, - progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent} + progressValueEstimator::Function; n::Integer=3) where {T1<:agent} nthSample = 0 while nthSample < n @@ -138,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, thoughtDict = decisionMaker(a, node.state) newNodeKey, newstate, reward, isterminalstate = - MCTStransition(a, node.state, thoughtDict, isterminal) + MCTStransition(a, node.state, thoughtDict) # add progressValueEstimator stateevaluation, statevalue = progressValueEstimator(a, newstate) @@ -148,69 +147,78 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, reward, isterminalstate, node, Dict{String, MCTSNode}()) end nthSample += 1 - catch - # skip this child node if error occurs - println("retry node expand") + catch e + io = IOBuffer() + showerror(io, e) + errorMsg = String(take!(io)) + st = sprint((io, v) -> show(io, "text/plain", v), stacktrace(catch_backtrace())) + println("") + @warn "Error occurred: $errorMsg\n$st" + println("") end end end -""" + + +""" Simulate interactions between agent and environment # Arguments + - `a::T` + one of YiemAgent's agent - `node::MCTSNode` node that will be a simulation starting point. + - `decisionMaker::Function` + function that receive state return Thought and Action # Return + - `simTrajectoryReward::Number` # Example ```jldoctest julia> ``` -# TODO - - [] update docstring - - [x] implement the function - - [] check for the terminal state (node.reward != 0), break if it is terminal state - # Signature """ -function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, - isterminal::Function, maxDepth::Int; n=3)::Number +function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, + maxDepth::Int; n=3)::Number where {T<:agent} simTrajectoryReward = 0.0 for depth in 1:maxDepth simTrajectoryReward += node.reward - if node.isterminalrd + if node.isterminal break else - expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) + expand(a, node, decisionMaker, progressValueEstimator, n=n) node = selectChildNode(node) end end - #BUG new expanded state has reward but it is not included because it is over maxdept by 1 state + return simTrajectoryReward end -""" +""" Backpropagate reward along the simulation chain # Arguments + - `node::MCTSNode` + node of a search tree + - `simTrajectoryReward::T` + total reward from all node in simulation trajectory # Return + - `No return` # Example ```jldoctest julia> ``` -# TODO - - [] update docstring - - [WORKING] implement the function - # Signature """ -function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9) +function backpropagate(node::MCTSNode, simTrajectoryReward::T; + discountRewardCoeff::AbstractFloat=0.9) where {T<:Number} while !isroot(node) # Update the statistics of the current node based on the result of the playout node.visits += 1 @@ -260,8 +268,8 @@ julia> thoughtDict = Dict( # Signature """ -function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function - )::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} +function MCTStransition(a::T1, state::T2, thoughtDict::T3 + )::Tuple{String, Dict{Symbol, <:Any}, <:Number, Bool} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} actionname = thoughtDict[:action][:name] actioninput = thoughtDict[:action][:input] @@ -383,10 +391,6 @@ end julia> ``` -# TODO - [] update docs - [TESTING] implement the function - # Signature """ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false @@ -437,7 +441,6 @@ function runMCTS( decisionMaker::Function, progressValueEstimator::Function, reflector::Function, - isterminal::Function, n::Integer, maxDepth::Integer, maxIterations::Integer, @@ -455,10 +458,10 @@ function runMCTS( # do nothing then go directly to backpropagation backpropagate(leafNode, node.reward) else - expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) + expand(a, node, decisionMaker, progressValueEstimator, n=n) leafNode = UCTselect(node, w) - simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, - isterminal, maxDepth, n=n) + simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator, + maxDepth, n=n) backpropagate(leafNode, simTrajectoryReward) end end diff --git a/test/test_1.jl b/test/test_1.jl index c4da5ec..20bb6e5 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -42,6 +42,9 @@ outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( :text=> "It will be Thai dishes.", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -52,7 +55,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "50 bucks.", + :text=> "I would spend up to 50 bucks.", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -64,6 +70,9 @@ outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( :text=> "I like full-bodied Red wine with low tannin.", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -74,28 +83,22 @@ outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( :text=> "What do you have?", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) - -outgoingMsg = Dict( - :msgMeta=> msgMeta, - :payload=> Dict( - :text=> "OK, I'll take it.", - ) -) -result = GeneralUtils.sendMqttMsg(outgoingMsg) - - - - outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( :text=> "Dry please.", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -107,6 +110,9 @@ outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( :text=> "You did not gave me any choice.", + :select=> nothing, + :reward=> -1, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -117,7 +123,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "Yes.", + :text=> "Are there any other options?", + :select=> nothing, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg) @@ -125,5 +134,31 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) +outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> "Yep.", + :select=> nothing, + :reward=> 0, + :isterminal=> false, + ) +) +result = GeneralUtils.sendMqttMsg(outgoingMsg) + + + + +outgoingMsg = Dict( + :msgMeta=> msgMeta, + :payload=> Dict( + :text=> "OK, I'll take it.", + :select=> 1, + :reward=> 1, + :isterminal=> true, + ) +) +result = GeneralUtils.sendMqttMsg(outgoingMsg) + +