update
This commit is contained in:
@@ -186,7 +186,7 @@ function decisionMaker(a::T1, state::T2)::Dict{Symbol, Any} where {T1<:agent, T2
|
|||||||
|
|
||||||
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
_response = GeneralUtils.sendReceiveMqttMsg(outgoingMsg)
|
||||||
_thoughtJsonStr = _response[:response][:text]
|
_thoughtJsonStr = _response[:response][:text]
|
||||||
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, "")
|
thoughtJsonStr = jsoncorrection(a, _thoughtJsonStr, responseformat)
|
||||||
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
thoughtDict = copy(JSON3.read(thoughtJsonStr))
|
||||||
pprint(thoughtDict)
|
pprint(thoughtDict)
|
||||||
return thoughtDict
|
return thoughtDict
|
||||||
@@ -324,11 +324,14 @@ function reflector()
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
"""
|
""" Determine whether the state is a terminal state
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
- `state::T`
|
||||||
|
a game state
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
|
- `(isterminal, reward)::Tuple{Bool, Number}`
|
||||||
|
|
||||||
# Example
|
# Example
|
||||||
```jldoctest
|
```jldoctest
|
||||||
@@ -336,13 +339,19 @@ julia>
|
|||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [x] update docstring
|
||||||
- [] implement the function
|
- [TESTING] implement the function
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict}
    # Decide whether `state` is terminal and which reward it earns.
    # Returns `(true, 1)` when the latest observation in `state[:thoughtHistory]`
    # contains both "<<" and ">>", otherwise `(false, 0)`.
    thoughtHistory = state[:thoughtHistory]
    latestObservationKey, _ = GeneralUtils.findHighestIndexKey(thoughtHistory, "Observation")

    # The lookup may hand back a key that is not present in the history when no
    # observation has been recorded yet; such a state cannot be terminal.
    if !haskey(thoughtHistory, latestObservationKey)
        return (false, 0)
    end
    latestObservation = thoughtHistory[latestObservationKey]

    # terminal condition: the user selects a wine by putting <<winename>> in the
    # latest observation
    if occursin("<<", latestObservation) && occursin(">>", latestObservation)
        return (true, 1)
    end

    # Explicit non-terminal result. Falling off the end would return `nothing`,
    # which cannot be converted to the declared `Tuple{Bool, Number}` return type.
    return (false, 0)
end
|
||||||
|
|
||||||
|
|
||||||
@@ -423,7 +432,6 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||||
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
:customerinfo=> deepcopy(a.keywordinfo[:customerinfo]),
|
||||||
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
:storeinfo=> deepcopy(a.keywordinfo[:storeinfo]),
|
||||||
|
|
||||||
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
:thoughtHistory=> OrderedDict{Symbol, Any}( # contain question, thought_1, action_1, observation_1, thought_2, ...
|
||||||
:Question=> userinput[:text],
|
:Question=> userinput[:text],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2,
|
|||||||
correctjson = incorrectjson
|
correctjson = incorrectjson
|
||||||
break
|
break
|
||||||
catch
|
catch
|
||||||
println("Attempting correct JSON string. $attempting")
|
println("Attempting correct JSON string. $attemptround")
|
||||||
_prompt =
|
_prompt =
|
||||||
"""
|
"""
|
||||||
Your goal is to correct a given incorrect JSON string.
|
Your goal is to correct a given incorrect JSON string.
|
||||||
|
|||||||
30
src/mcts.jl
30
src/mcts.jl
@@ -53,6 +53,7 @@ struct MCTSNode{T<:AbstractDict}
|
|||||||
visits::Integer
|
visits::Integer
|
||||||
progressValue::Number
|
progressValue::Number
|
||||||
reward::Number
|
reward::Number
|
||||||
|
isterminal::Bool
|
||||||
parent::Union{MCTSNode, Nothing}
|
parent::Union{MCTSNode, Nothing}
|
||||||
children::Dict{String, MCTSNode}
|
children::Dict{String, MCTSNode}
|
||||||
end
|
end
|
||||||
@@ -126,19 +127,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
# sampling action from decisionMaker
|
# sampling action from decisionMaker
|
||||||
for sample in 1:n
|
for sample in 1:n
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
@show node.state
|
|
||||||
@show thoughtDict
|
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
|
||||||
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
||||||
|
|
||||||
#[WORKING] check for terminal state
|
|
||||||
|
|
||||||
|
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0,
|
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
||||||
node, Dict{String, MCTSNode}())
|
reward, isterminal, node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -242,15 +239,12 @@ julia> thoughtDict = Dict(
|
|||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [PENDING] add other actions
|
- [PENDING] add other actions
|
||||||
- [] add embedding of newstate and store in newstate[:embedding]
|
- [] add embedding of newstate and store in newstate[:embedding]
|
||||||
|
- [x] check for terminal state and assign reward
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function MCTStransition(a::T1, state::T2,
|
function MCTStransition(a::T1, state::T2,
|
||||||
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||||
println("")
|
|
||||||
# latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
|
|
||||||
# latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
|
|
||||||
# _action = thoughtDict[:Action]
|
|
||||||
actionname = thoughtDict[:Action][:name]
|
actionname = thoughtDict[:Action][:name]
|
||||||
actioninput = thoughtDict[:Action][:input]
|
actioninput = thoughtDict[:Action][:input]
|
||||||
|
|
||||||
@@ -266,8 +260,9 @@ function MCTStransition(a::T1, state::T2,
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
_, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought")
|
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
|
||||||
nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1
|
"Thought")
|
||||||
|
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
||||||
latestThoughtKey = Symbol("Thought_$nextIndice")
|
latestThoughtKey = Symbol("Thought_$nextIndice")
|
||||||
latestActionKey = Symbol("Action_$nextIndice")
|
latestActionKey = Symbol("Action_$nextIndice")
|
||||||
|
|
||||||
@@ -279,8 +274,9 @@ function MCTStransition(a::T1, state::T2,
|
|||||||
newstate[:thoughtHistory][latestObservationKey] = response
|
newstate[:thoughtHistory][latestObservationKey] = response
|
||||||
|
|
||||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
|
isterminalstate, reward = isterminal(newstate)
|
||||||
|
|
||||||
return newNodeKey, newstate
|
return newNodeKey, newstate, isterminalstate, reward
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -328,7 +324,7 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [WORKING] implement the function
|
- [x] implement the function
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
@@ -397,7 +393,7 @@ function runMCTS(
|
|||||||
maxIterations::Integer,
|
maxIterations::Integer,
|
||||||
w::Float64) where {T1<:agent}
|
w::Float64) where {T1<:agent}
|
||||||
|
|
||||||
root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
for _ in 1:maxIterations
|
for _ in 1:maxIterations
|
||||||
node = root
|
node = root
|
||||||
|
|||||||
Reference in New Issue
Block a user