This commit is contained in:
narawat lamaiin
2024-04-26 21:25:46 +07:00
parent 54fc9c67bc
commit 8f68d177e7
3 changed files with 173 additions and 140 deletions

View File

@@ -33,6 +33,122 @@ using ..type, ..util, ..llmfunction, ..mcts
# ---------------------------------------------- 100 --------------------------------------------- #
""" Think and choose action
Arguments\n
-----
state::T
a game state
Return\n
-----
thought::Dict
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[x] implement the function
[] implement RAG to pull similar experience
[] use iterative prompting to ensure JSON format
Signature\n
-----
"""
function decisionMaker(state::T) where {T<:AbstractDict}
customerinfo =
"""
I will give you the following information about customer:
$(JSON3.write(state[:customerinfo]))
"""
storeinfo =
"""
I will give you the following information about your store:
$(JSON3.write(state[:storeinfo]))
"""
#[x]
prompt =
"""
You are a helpful sommelier working for a wine store.
You helps users by searching wine that match the user preferences from your inventory.
$customerinfo
You must follow the
"""
thought = iterativeprompting(prompt, syntaxcheck_json)
error("--> decisionMaker")
return thought
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
end
""" Chat with llm.
Arguments\n
@@ -117,7 +233,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
:question=> userinput[:text],
)
)
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0)
error("---> bestplan")
@@ -178,6 +294,7 @@ end
end # module interface

View File

@@ -5,10 +5,11 @@
module mcts
export MCTSNode, runMCTS, decisionMaker, stateValueEstimator, reflector
export MCTSNode, runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- #
@@ -123,17 +124,17 @@ end
TODO\n
-----
[] update docstring
[WORKING] implement the function
[x] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T<:AbstractDict}
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker
for sample in 1:n
thought = decisionMaker(state)
thought = decisionMaker(a, state)
error("--> expand")
newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children)
@@ -265,77 +266,7 @@ end
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
""" Think and choose action
Arguments\n
-----
state::T
a game state
Return\n
-----
thought::Dict
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
[] implement RAG to pull similar experience
[] use iterative prompting to ensure JSON format
Signature\n
-----
"""
function decisionMaker(state::T) where {T<:AbstractDict}
customerinfo =
"""
I will give you the following information about customer:
$(JSON3.write(state[:customerinfo]))
"""
storeinfo =
"""
I will give you the following information about your store:
$(JSON3.write(state[:storeinfo]))
"""
prompt =
"""
You are a helpful sommelier working for a wine store.
You helps users by searching wine that match the user preferences from your inventory.
$customerinfo
You must follow the
"""
result = iterativeprompting(prompt, syntaxcheck("json"))
error("--> decisionMaker")
end
"""
@@ -359,11 +290,37 @@ end
Signature\n
-----
"""
function iterativeprompting(prommpt::String, verification::Function)
function iterativeprompting(a::T, prompt::String, verification::Function) where {T<:agent}
msgMeta = GeneralUtils.generate_msgMeta(
a.config[:thirdPartyService][:text2textinstruct],
)
outgoing =
# iteration loop
while true
# send prompt to LLM
response = GeneralUtils.sendReceiveMqttMsg()
# check for correctness and get feedback
if correct
break
else # add feedback to prompt
end
end
return
end
"""
""" Check JSON format correctness, provide
Arguments\n
-----
@@ -387,62 +344,19 @@ end
"""
function syntaxcheck_json(jsonstring::String)::NamedTuple
success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring)
if !success # gives feedback
return
else
end
return (success, result, critique)
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
end
"""
@@ -536,9 +450,9 @@ end
Signature\n
-----
"""
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
function runMCTS(a::T, initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, n::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64)
maxIterations::Integer, w::Float64) where {T<:agent}
statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
@@ -548,7 +462,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun
node = select(node, w)
end
expand(node, node.state, decisionMaker, stateValueEstimator, n=n)
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
error("---> runMCTS")
leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth)

View File

@@ -126,16 +126,18 @@ end
function sommelier(
receiveUserMsgChannel::Channel,
receiveInternalMsgChannel::Channel,
msgMeta::Dict= GeneralUtils.generate_msgMeta("N/A"),
config::Dict = Dict(
:frontend=> Dict(
:mqtttopic=> nothing
:mqttServerInfo=> Dict(
:broker=> nothing,
:port=> nothing,
),
:internal=> Dict(
:mqtttopic=> nothing
:receivemsg=> Dict(
:prompt=> nothing, # topic to receive prompt i.e. frontend send msg to this topic
:internal=> nothing,
),
:text2text=> Dict(
:mqtttopic=> "txt2text/api/v1/prompt/gpu",
:thirdPartyService=> Dict(
:text2textinstruct=> nothing,
:text2textchat=> nothing,
),
)
;