This commit is contained in:
narawat lamaiin
2024-04-26 21:25:46 +07:00
parent 54fc9c67bc
commit 8f68d177e7
3 changed files with 173 additions and 140 deletions

View File

@@ -33,6 +33,122 @@ using ..type, ..util, ..llmfunction, ..mcts
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Think and choose action
Arguments\n
-----
state::T
a game state
Return\n
-----
thought::Dict
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[x] implement the function
[] implement RAG to pull similar experience
[] use iterative prompting to ensure JSON format
Signature\n
-----
"""
function decisionMaker(state::T) where {T<:AbstractDict}
customerinfo =
"""
I will give you the following information about customer:
$(JSON3.write(state[:customerinfo]))
"""
storeinfo =
"""
I will give you the following information about your store:
$(JSON3.write(state[:storeinfo]))
"""
#[x]
prompt =
"""
You are a helpful sommelier working for a wine store.
You helps users by searching wine that match the user preferences from your inventory.
$customerinfo
You must follow the
"""
thought = iterativeprompting(prompt, syntaxcheck_json)
error("--> decisionMaker")
return thought
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
end
""" Chat with llm. """ Chat with llm.
Arguments\n Arguments\n
@@ -117,7 +233,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
:question=> userinput[:text], :question=> userinput[:text],
) )
) )
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, bestplan = runMCTS(a, initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0) 3, 10, 1000, 1.0)
error("---> bestplan") error("---> bestplan")
@@ -178,6 +294,7 @@ end
end # module interface end # module interface

View File

@@ -5,10 +5,11 @@
module mcts module mcts
export MCTSNode, runMCTS, decisionMaker, stateValueEstimator, reflector export MCTSNode, runMCTS
using Dates, UUIDs, DataStructures, JSON3, Random using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils using GeneralUtils
using ..type
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
@@ -123,17 +124,17 @@ end
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[WORKING] implement the function [x] implement the function
Signature\n Signature\n
----- -----
""" """
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function; function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T<:AbstractDict} n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker # sampling action from decisionMaker
for sample in 1:n for sample in 1:n
thought = decisionMaker(state) thought = decisionMaker(a, state)
error("--> expand") error("--> expand")
newState = transition(node.state, action) #[] Implement your transition function newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children) if newState keys(node.children)
@@ -265,77 +266,7 @@ end
""" """
isLeaf(node::MCTSNode)::Bool = isempty(node.children) isLeaf(node::MCTSNode)::Bool = isempty(node.children)
""" Think and choose action
Arguments\n
-----
state::T
a game state
Return\n
-----
thought::Dict
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] implement the function
[] implement RAG to pull similar experience
[] use iterative prompting to ensure JSON format
Signature\n
-----
"""
function decisionMaker(state::T) where {T<:AbstractDict}
customerinfo =
"""
I will give you the following information about customer:
$(JSON3.write(state[:customerinfo]))
"""
storeinfo =
"""
I will give you the following information about your store:
$(JSON3.write(state[:storeinfo]))
"""
prompt =
"""
You are a helpful sommelier working for a wine store.
You helps users by searching wine that match the user preferences from your inventory.
$customerinfo
You must follow the
"""
result = iterativeprompting(prompt, syntaxcheck("json"))
error("--> decisionMaker")
end
""" """
@@ -359,11 +290,37 @@ end
Signature\n Signature\n
----- -----
""" """
function iterativeprompting(prommpt::String, verification::Function) function iterativeprompting(a::T, prompt::String, verification::Function) where {T<:agent}
msgMeta = GeneralUtils.generate_msgMeta(
a.config[:thirdPartyService][:text2textinstruct],
)
outgoing =
# iteration loop
while true
# send prompt to LLM
response = GeneralUtils.sendReceiveMqttMsg()
# check for correctness and get feedback
if correct
break
else # add feedback to prompt
end
end
return
end end
"""
""" Check JSON format correctness, provide
Arguments\n Arguments\n
----- -----
@@ -387,62 +344,19 @@ end
""" """
function syntaxcheck_json(jsonstring::String)::NamedTuple function syntaxcheck_json(jsonstring::String)::NamedTuple
success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring) success, result, errormsg, st = GeneralUtils.showstracktrace(JSON3.read, jsonstring)
if !success # gives feedback
return
else
end
return (success, result, critique)
end end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function stateValueEstimator()
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function reflector()
end
""" """
@@ -536,9 +450,9 @@ end
Signature\n Signature\n
----- -----
""" """
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function, function runMCTS(a::T, initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, n::Integer, maxDepth::Integer, reflector::Function, n::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64) maxIterations::Integer, w::Float64) where {T<:agent}
statetype = typeof(initialState) statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}()) root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
@@ -548,7 +462,7 @@ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Fun
node = select(node, w) node = select(node, w)
end end
expand(node, node.state, decisionMaker, stateValueEstimator, n=n) expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
error("---> runMCTS") error("---> runMCTS")
leaf_node = node.children[node.state] # mark leaf node leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth) reward = simulate(leaf_node.state, maxDepth)

View File

@@ -126,16 +126,18 @@ end
function sommelier( function sommelier(
receiveUserMsgChannel::Channel, receiveUserMsgChannel::Channel,
receiveInternalMsgChannel::Channel, receiveInternalMsgChannel::Channel,
msgMeta::Dict= GeneralUtils.generate_msgMeta("N/A"),
config::Dict = Dict( config::Dict = Dict(
:frontend=> Dict( :mqttServerInfo=> Dict(
:mqtttopic=> nothing :broker=> nothing,
:port=> nothing,
), ),
:internal=> Dict( :receivemsg=> Dict(
:mqtttopic=> nothing :prompt=> nothing, # topic to receive prompt i.e. frontend send msg to this topic
:internal=> nothing,
), ),
:text2text=> Dict( :thirdPartyService=> Dict(
:mqtttopic=> "txt2text/api/v1/prompt/gpu", :text2textinstruct=> nothing,
:text2textchat=> nothing,
), ),
) )
; ;