This commit is contained in:
narawat lamaiin
2024-05-04 21:17:02 +07:00
parent dea3f0260e
commit 8907156522
3 changed files with 28 additions and 24 deletions

View File

@@ -186,7 +186,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
_thoughtJsonStr = _response[:response][:text]
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, "")
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, responseformat)
thoughtDict = copy(JSON3.read(thoughtJsonStr))
pprint(thoughtDict)
return thoughtDict
@@ -324,11 +324,14 @@ function reflector()
end
"""
""" Determine whether the state is a terminal state
# Arguments
- `state::T`
a game state
# Return
- `(isterminal, reward)::Tuple{Bool, Number}`
# Example
```jldoctest
@@ -336,13 +339,19 @@ julia>
```
# TODO
- [] update docstring
- [] implement the function
- [x] update docstring
- [TESTING] implement the function
# Signature
"""
function isterminal()
function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict}
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation")
latestObservation = state[:thoughtHistory][latestObservationKey]
# terminal condition is when the user select wine by putting <<winename>> in latest observation
if occursin("<<", latestObservation) && occursin(">>", latestObservation)
return true, 1
end
end
@@ -423,7 +432,6 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
:Question=> userinput[:text],
)

View File

@@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2,
correctjson = incorrectjson
break
catch
println("Attempting correct JSON string. $attempting")
println("Attempting correct JSON string. $attemptround")
_prompt =
"""
Your goal is to correct a given incorrect JSON string.

View File

@@ -53,6 +53,7 @@ struct MCTSNode{T<:AbstractDict}
visits::Integer
progressValue::Number
reward::Number
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
@@ -126,19 +127,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
# sampling action from decisionMaker
for sample in 1:n
thoughtDict = decisionMaker(a, node.state)
@show node.state
@show thoughtDict
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
progressRationale, progressValue = progressValueEstimator(a, newstate)
#[WORKING] check for terminal state
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0,
node, Dict{String, MCTSNode}())
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
reward, isterminal, node, Dict{String, MCTSNode}())
end
end
end
@@ -242,15 +239,12 @@ julia> thoughtDict = Dict(
- [] update docstring
- [PENDING] add other actions
- [] add embedding of newstate and store in newstate[:embedding]
- [x] check for terminal state and assign reward
# Signature
"""
function MCTStransition(a::T1, state::T2,
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
println("")
# latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
# latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
# _action = thoughtDict[:Action]
actionname = thoughtDict[:Action][:name]
actioninput = thoughtDict[:Action][:input]
@@ -266,8 +260,9 @@ function MCTStransition(a::T1, state::T2,
end
_, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought")
nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
"Thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("Thought_$nextIndice")
latestActionKey = Symbol("Action_$nextIndice")
@@ -279,8 +274,9 @@ function MCTStransition(a::T1, state::T2,
newstate[:thoughtHistory][latestObservationKey] = response
newNodeKey = GeneralUtils.uuid4snakecase()
isterminalstate, reward = isterminal(newstate)
return newNodeKey, newstate
return newNodeKey, newstate, isterminalstate, reward
end
@@ -328,7 +324,7 @@ julia>
# TODO
- [] update docstring
- [WORKING] implement the function
- [x] implement the function
# Signature
"""
@@ -397,7 +393,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root