diff --git a/src/interface.jl b/src/interface.jl index 18a355c..c6053e3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -109,7 +109,10 @@ function conversation(a::T, userinput::Dict) where {T<:agent} else #[WORKING] new thinking initialState = Dict( - :info=> Dict{Symbol, Any}(), # keyword info + + # deepcopy the info to prevent modifying the info unintentionally during MCTS planning + :info=> deepcopy(a.keywordinfo), + :thoughtHistory=> Dict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :question=> userinput[:text], ) @@ -117,6 +120,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, 3, 10, 1000, 1.0) error("---> bestplan") + # actor loop(bestplan) end diff --git a/src/mcts.jl b/src/mcts.jl index 7891a5b..9fee8b2 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -133,7 +133,8 @@ function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEst # sampling action from decisionMaker for sample in 1:n - result = decisionMaker(state) + thought = decisionMaker(state) + error("--> expand") newState = transition(node.state, action) #[] Implement your transition function if newState ∉ keys(node.children) node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}()) @@ -266,6 +267,78 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children) """ Think and choose action + Arguments\n + ----- + state::T + a game state + + Return\n + ----- + thought::Dict + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [WORKING] implement the function + [] implement RAG to pull similar experience + [] use iterative prompting to ensure JSON format + + Signature\n + ----- +""" +function decisionMaker(state::T) where {T<:AbstractDict} + customerinfo = + """ + I will give you the following information about customer: + $(JSON3.write(state[:customerinfo])) + """ + + storeinfo = + """ + I will give you the following information about your store: + $(JSON3.write(state[:storeinfo])) + """ + + prompt = + """ + You are a helpful sommelier working for a wine store. + You helps users by searching wine that match the user preferences from your inventory. + + $customerinfo + + You must follow the + """ + + + + result = iterativeprompting(prompt, syntaxcheck("json")) + + + + + + + + + + + + + + + + + error("--> decisionMaker") +end + +""" + Arguments\n ----- @@ -281,16 +354,44 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children) TODO\n ----- [] update docstring - [] implement the function - [WORKING] implement RAG to pull similar experience + [WORKING] implement the function Signature\n ----- """ -function decisionMaker(state::T) where {T<:AbstractDict} +function iterativeprompting(prommpt::String, verification::Function) end +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [WORKING] implement the function + + Signature\n + ----- +""" +function syntaxcheck_json(jsonstring::String)::NamedTuple + success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring) + + + return +end + """ Arguments\n @@ -387,7 +488,7 @@ end TODO\n ----- [] update docstring - [WORKING] implement the function + [] implement the function Signature\n ----- @@ -440,7 +541,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun maxIterations::Integer, w::Float64) statetype = typeof(initialState) root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}()) - error("---> runMCTS") + for _ in 1:maxIterations node = root while !isLeaf(node) @@ -448,7 +549,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun end expand(node, node.state, decisionMaker, stateValueEstimator, n=n) - + error("---> runMCTS") leaf_node = node.children[node.state] # mark leaf node reward = simulate(leaf_node.state, maxDepth) backpropagate(leaf_node, reward) diff --git a/src/type.jl b/src/type.jl index 3beb072..c987267 100644 --- a/src/type.jl +++ b/src/type.jl @@ -98,8 +98,8 @@ abstract type agent end maxHistoryMsg::Integer # 31th and earlier messages will get summarized keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}( - :userinfo => Dict{Symbol, Any}(), - :retailerinfo => Dict{Symbol, Any}(), + :customerinfo => Dict{Symbol, Any}(), + :storeinfo => Dict{Symbol, Any}(), ) mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()