update
This commit is contained in:
@@ -109,7 +109,10 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
|
||||
else #[WORKING] new thinking
|
||||
initialState = Dict(
|
||||
:info=> Dict{Symbol, Any}(), # keyword info
|
||||
|
||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||
:info=> deepcopy(a.keywordinfo),
|
||||
|
||||
:thoughtHistory=> Dict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||
:question=> userinput[:text],
|
||||
)
|
||||
@@ -117,6 +120,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
||||
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
|
||||
3, 10, 1000, 1.0)
|
||||
error("---> bestplan")
|
||||
|
||||
# actor loop(bestplan)
|
||||
|
||||
end
|
||||
|
||||
115
src/mcts.jl
115
src/mcts.jl
@@ -133,7 +133,8 @@ function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEst
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
result = decisionMaker(state)
|
||||
thought = decisionMaker(state)
|
||||
error("--> expand")
|
||||
newState = transition(node.state, action) #[] Implement your transition function
|
||||
if newState ∉ keys(node.children)
|
||||
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
|
||||
@@ -266,6 +267,78 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
|
||||
""" Think and choose action
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
state::T
|
||||
a game state
|
||||
|
||||
Return\n
|
||||
-----
|
||||
thought::Dict
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[WORKING] implement the function
|
||||
[] implement RAG to pull similar experience
|
||||
[] use iterative prompting to ensure JSON format
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function decisionMaker(state::T) where {T<:AbstractDict}
|
||||
customerinfo =
|
||||
"""
|
||||
I will give you the following information about customer:
|
||||
$(JSON3.write(state[:customerinfo]))
|
||||
"""
|
||||
|
||||
storeinfo =
|
||||
"""
|
||||
I will give you the following information about your store:
|
||||
$(JSON3.write(state[:storeinfo]))
|
||||
"""
|
||||
|
||||
prompt =
|
||||
"""
|
||||
You are a helpful sommelier working for a wine store.
|
||||
You helps users by searching wine that match the user preferences from your inventory.
|
||||
|
||||
$customerinfo
|
||||
|
||||
You must follow the
|
||||
"""
|
||||
|
||||
|
||||
|
||||
result = iterativeprompting(prompt, syntaxcheck("json"))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
error("--> decisionMaker")
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
@@ -281,16 +354,44 @@ isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[] implement the function
|
||||
[WORKING] implement RAG to pull similar experience
|
||||
[WORKING] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function decisionMaker(state::T) where {T<:AbstractDict}
|
||||
function iterativeprompting(prommpt::String, verification::Function)
|
||||
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
-----
|
||||
|
||||
Return\n
|
||||
-----
|
||||
|
||||
Example\n
|
||||
-----
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[WORKING] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
"""
|
||||
function syntaxcheck_json(jsonstring::String)::NamedTuple
|
||||
success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring)
|
||||
|
||||
|
||||
return
|
||||
end
|
||||
|
||||
"""
|
||||
|
||||
Arguments\n
|
||||
@@ -387,7 +488,7 @@ end
|
||||
TODO\n
|
||||
-----
|
||||
[] update docstring
|
||||
[WORKING] implement the function
|
||||
[] implement the function
|
||||
|
||||
Signature\n
|
||||
-----
|
||||
@@ -440,7 +541,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun
|
||||
maxIterations::Integer, w::Float64)
|
||||
statetype = typeof(initialState)
|
||||
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||
error("---> runMCTS")
|
||||
|
||||
for _ in 1:maxIterations
|
||||
node = root
|
||||
while !isLeaf(node)
|
||||
@@ -448,7 +549,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun
|
||||
end
|
||||
|
||||
expand(node, node.state, decisionMaker, stateValueEstimator, n=n)
|
||||
|
||||
error("---> runMCTS")
|
||||
leaf_node = node.children[node.state] # mark leaf node
|
||||
reward = simulate(leaf_node.state, maxDepth)
|
||||
backpropagate(leaf_node, reward)
|
||||
|
||||
@@ -98,8 +98,8 @@ abstract type agent end
|
||||
|
||||
maxHistoryMsg::Integer # 31th and earlier messages will get summarized
|
||||
keywordinfo::Dict{Symbol, Any} = Dict{Symbol, Any}(
|
||||
:userinfo => Dict{Symbol, Any}(),
|
||||
:retailerinfo => Dict{Symbol, Any}(),
|
||||
:customerinfo => Dict{Symbol, Any}(),
|
||||
:storeinfo => Dict{Symbol, Any}(),
|
||||
)
|
||||
mctstree::Dict{Symbol, Any} = Dict{Symbol, Any}()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user