From 8907156522d8ed30fd66ef06926f0b311b1afbfe Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Sat, 4 May 2024 21:17:02 +0700 Subject: [PATCH] update --- src/interface.jl | 20 ++++++++++++++------ src/llmfunction.jl | 2 +- src/mcts.jl | 30 +++++++++++++----------------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 76b9e02..89dadac 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -186,7 +186,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2 _response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) _thoughtJsonStr = _response[:response][:text] - thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, "") + thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, responseformat) thoughtDict = copy(JSON3.read(thoughtJsonStr)) pprint(thoughtDict) return thoughtDict @@ -324,11 +324,14 @@ function reflector() end -""" +""" Determine whether the state is a terminal state # Arguments + - `state::T` + a game state # Return + - `(isterminal, reward)::Tuple{Bool, Number}` # Example ```jldoctest @@ -336,13 +339,19 @@ julia> ``` # TODO - - [] update docstring - - [] implement the function + - [x] update docstring + - [TESTING] implement the function # Signature """ -function isterminal() +function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict} + latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation") + latestObservation = state[:thoughtHistory][latestObservationKey] + # terminal condition is when the user select wine by putting <> in latest observation + if occursin("<<", latestObservation) && occursin(">>", latestObservation) + return true, 1 + end end @@ -423,7 +432,6 @@ function conversation(a::T, userinput::Dict) where {T<:agent} # deepcopy the info to prevent modifying the info unintentionally during MCTS planning :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), - :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :Question=> userinput[:text], ) diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 2a54bc0..4227073 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2, correctjson = incorrectjson break catch - println("Attempting correct JSON string. $attempting") + println("Attempting correct JSON string. $attemptround") _prompt = """ Your goal is to correct a given incorrect JSON string. diff --git a/src/mcts.jl b/src/mcts.jl index 41dc874..d840e33 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -53,6 +53,7 @@ struct MCTSNode{T<:AbstractDict} visits::Integer progressValue::Number reward::Number + isterminal::Bool parent::Union{MCTSNode, Nothing} children::Dict{String, MCTSNode} end @@ -126,19 +127,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, # sampling action from decisionMaker for sample in 1:n thoughtDict = decisionMaker(a, node.state) - @show node.state - @show thoughtDict - newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function + + newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict) # add progressValueEstimator progressRationale, progressValue = progressValueEstimator(a, newstate) - #[WORKING] check for terminal state - - if newNodeKey ∉ keys(node.children) - node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0, - node, Dict{String, MCTSNode}()) + node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, + reward, isterminal, node, Dict{String, MCTSNode}()) end end end @@ -242,15 +239,12 @@ julia> thoughtDict = Dict( - [] update docstring - [PENDING] add other actions - [] add embedding of newstate and store in newstate[:embedding] + - [x] check for terminal state and assign reward # Signature """ function MCTStransition(a::T1, state::T2, thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} - println("") - # latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought") - # latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action") - # _action = thoughtDict[:Action] actionname = thoughtDict[:Action][:name] actioninput = thoughtDict[:Action][:input] @@ -266,8 +260,9 @@ function MCTStransition(a::T1, state::T2, end - _, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought") - nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1 + latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], + "Thought") + nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1 latestThoughtKey = Symbol("Thought_$nextIndice") latestActionKey = Symbol("Action_$nextIndice") @@ -279,8 +274,9 @@ function MCTStransition(a::T1, state::T2, newstate[:thoughtHistory][latestObservationKey] = response newNodeKey = GeneralUtils.uuid4snakecase() + isterminalstate, reward = isterminal(newstate) - return newNodeKey, newstate + return newNodeKey, newstate, isterminalstate, reward end @@ -328,7 +324,7 @@ julia> # TODO - [] update docstring - - [WORKING] implement the function + - [x] implement the function # Signature """ @@ -397,7 +393,7 @@ function runMCTS( maxIterations::Integer, w::Float64) where {T1<:agent} - root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}()) + root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) for _ in 1:maxIterations node = root