This commit is contained in:
narawat lamaiin
2024-05-04 21:17:02 +07:00
parent dea3f0260e
commit 8907156522
3 changed files with 28 additions and 24 deletions

View File

@@ -186,7 +186,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg) _response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
_thoughtJsonStr = _response[:response][:text] _thoughtJsonStr = _response[:response][:text]
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, "") thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, responseformat)
thoughtDict = copy(JSON3.read(thoughtJsonStr)) thoughtDict = copy(JSON3.read(thoughtJsonStr))
pprint(thoughtDict) pprint(thoughtDict)
return thoughtDict return thoughtDict
@@ -324,11 +324,14 @@ function reflector()
end end
""" """ Determine whether the state is a terminal state
# Arguments # Arguments
- `state::T`
a game state
# Return # Return
- `(isterminal, reward)::Tuple{Bool, Number}`
# Example # Example
```jldoctest ```jldoctest
@@ -336,13 +339,19 @@ julia>
``` ```
# TODO # TODO
- [] update docstring - [x] update docstring
- [] implement the function - [TESTING] implement the function
# Signature # Signature
""" """
function isterminal() function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict}
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation")
latestObservation = state[:thoughtHistory][latestObservationKey]
# terminal condition is when the user select wine by putting <<winename>> in latest observation
if occursin("<<", latestObservation) && occursin(">>", latestObservation)
return true, 1
end
end end
@@ -423,7 +432,6 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning # deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]), :customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]), :storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ... :thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
:Question=> userinput[:text], :Question=> userinput[:text],
) )

View File

@@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2,
correctjson = incorrectjson correctjson = incorrectjson
break break
catch catch
println("Attempting correct JSON string. $attempting") println("Attempting correct JSON string. $attemptround")
_prompt = _prompt =
""" """
Your goal is to correct a given incorrect JSON string. Your goal is to correct a given incorrect JSON string.

View File

@@ -53,6 +53,7 @@ struct MCTSNode{T<:AbstractDict}
visits::Integer visits::Integer
progressValue::Number progressValue::Number
reward::Number reward::Number
isterminal::Bool
parent::Union{MCTSNode, Nothing} parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode} children::Dict{String, MCTSNode}
end end
@@ -126,19 +127,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
# sampling action from decisionMaker # sampling action from decisionMaker
for sample in 1:n for sample in 1:n
thoughtDict = decisionMaker(a, node.state) thoughtDict = decisionMaker(a, node.state)
@show node.state
@show thoughtDict newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
# add progressValueEstimator # add progressValueEstimator
progressRationale, progressValue = progressValueEstimator(a, newstate) progressRationale, progressValue = progressValueEstimator(a, newstate)
#[WORKING] check for terminal state
if newNodeKey keys(node.children) if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0, node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
node, Dict{String, MCTSNode}()) reward, isterminal, node, Dict{String, MCTSNode}())
end end
end end
end end
@@ -242,15 +239,12 @@ julia> thoughtDict = Dict(
- [] update docstring - [] update docstring
- [PENDING] add other actions - [PENDING] add other actions
- [] add embedding of newstate and store in newstate[:embedding] - [] add embedding of newstate and store in newstate[:embedding]
- [x] check for terminal state and assign reward
# Signature # Signature
""" """
function MCTStransition(a::T1, state::T2, function MCTStransition(a::T1, state::T2,
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict} thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
println("")
# latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
# latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
# _action = thoughtDict[:Action]
actionname = thoughtDict[:Action][:name] actionname = thoughtDict[:Action][:name]
actioninput = thoughtDict[:Action][:input] actioninput = thoughtDict[:Action][:input]
@@ -266,8 +260,9 @@ function MCTStransition(a::T1, state::T2,
end end
_, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought") latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1 "Thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("Thought_$nextIndice") latestThoughtKey = Symbol("Thought_$nextIndice")
latestActionKey = Symbol("Action_$nextIndice") latestActionKey = Symbol("Action_$nextIndice")
@@ -279,8 +274,9 @@ function MCTStransition(a::T1, state::T2,
newstate[:thoughtHistory][latestObservationKey] = response newstate[:thoughtHistory][latestObservationKey] = response
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
isterminalstate, reward = isterminal(newstate)
return newNodeKey, newstate return newNodeKey, newstate, isterminalstate, reward
end end
@@ -328,7 +324,7 @@ julia>
# TODO # TODO
- [] update docstring - [] update docstring
- [WORKING] implement the function - [x] implement the function
# Signature # Signature
""" """
@@ -397,7 +393,7 @@ function runMCTS(
maxIterations::Integer, maxIterations::Integer,
w::Float64) where {T1<:agent} w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}()) root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for _ in 1:maxIterations for _ in 1:maxIterations
node = root node = root