This commit is contained in:
narawat lamaiin
2024-04-26 08:22:11 +07:00
parent b2c24e97ae
commit 54fc9c67bc
3 changed files with 115 additions and 10 deletions

View File

@@ -109,7 +109,10 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
else #[WORKING] new thinking else #[WORKING] new thinking
initialState = Dict( initialState = Dict(
:info=> Dict{Symbol, Any}(), # keyword info
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:info=> deepcopy(a.keywordinfo),
:thoughtHistory=> Dict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :thoughtHistory=> Dict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
:question=> userinput[:text], :question=> userinput[:text],
) )
@@ -117,6 +120,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0) 3, 10, 1000, 1.0)
error("---> bestplan") error("---> bestplan")
# actor loop(bestplan) # actor loop(bestplan)
end end

View File

@@ -133,7 +133,8 @@ function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEst
# sampling action from decisionMaker # sampling action from decisionMaker
for sample in 1:n for sample in 1:n
result = decisionMaker(state) thought = decisionMaker(state)
error("--> expand")
newState = transition(node.state, action) #[] Implement your transition function newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children) if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}()) node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
@@ -266,6 +267,78 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
""" Think and choose action """ Think and choose action
Arguments\n
-----
state::T
a game state
Return\n
-----
thought::Dict
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
[] implement RAG to pull similar experience
[] use iterative prompting to ensure JSON format
Signature\n
-----
"""
function decisionMaker(state::T) where {T<:AbstractDict}
customerinfo =
"""
I will give you the following information about customer:
$(JSON3.write(state[:customerinfo]))
"""
storeinfo =
"""
I will give you the following information about your store:
$(JSON3.write(state[:storeinfo]))
"""
prompt =
"""
You are a helpful sommelier working for a wine store.
You helps users by searching wine that match the user preferences from your inventory.
$customerinfo
You must follow the
"""
result = iterativeprompting(prompt, syntaxcheck("json"))
error("--> decisionMaker")
end
"""
Arguments\n Arguments\n
----- -----
@@ -281,16 +354,44 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[] implement the function [WORKING] implement the function
[WORKING] implement RAG to pull similar experience
Signature\n Signature\n
----- -----
""" """
function decisionMaker(state::T) where {T<:AbstractDict} function iterativeprompting(prommpt::String, verification::Function)
end end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
Signature\n
-----
"""
function syntaxcheck_json(jsonstring::String)::NamedTuple
success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring)
return
end
""" """
Arguments\n Arguments\n
@@ -387,7 +488,7 @@ end
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[WORKING] implement the function [] implement the function
Signature\n Signature\n
----- -----
@@ -440,7 +541,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun
maxIterations::Integer, w::Float64) maxIterations::Integer, w::Float64)
statetype = typeof(initialState) statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}()) root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
error("---> runMCTS")
for _ in 1:maxIterations for _ in 1:maxIterations
node = root node = root
while !isLeaf(node) while !isLeaf(node)
@@ -448,7 +549,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun
end end
expand(node, node.state, decisionMaker, stateValueEstimator, n=n) expand(node, node.state, decisionMaker, stateValueEstimator, n=n)
error("---> runMCTS")
leaf_node = node.children[node.state] # mark leaf node leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth) reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward) backpropagate(leaf_node, reward)

View File

@@ -98,8 +98,8 @@ abstract type agent end
maxHistoryMsg::Integer # 31th and earlier messages will get summarized maxHistoryMsg::Integer # 31th and earlier messages will get summarized
keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}( keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}(
:userinfo => Dict{Symbol, Any}(), :customerinfo => Dict{Symbol, Any}(),
:retailerinfo => Dict{Symbol, Any}(), :storeinfo => Dict{Symbol, Any}(),
) )
mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}() mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()