update
This commit is contained in:
58
src/mcts.jl
58
src/mcts.jl
@@ -110,33 +110,26 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
|
||||
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
thoughtJstr = decisionMaker(a, state)
|
||||
thoughtDict = copy(JSON3.read(thoughtJstr))
|
||||
|
||||
""" Example of thoughtDict
|
||||
Dict{Symbol, Any} with 3 entries:
|
||||
:Thought_1 => "The customer wants to buy a bottle of wine. This is a good start!"
|
||||
:Action_1 => Dict{Symbol, Any}(
|
||||
:action=>"Chatbox",
|
||||
:input=>"What occasion are you buying the wine for?"
|
||||
)
|
||||
:Observation_1 => ""
|
||||
"""
|
||||
thoughtDict = decisionMaker(a, state)
|
||||
@show state
|
||||
@show thoughtDict
|
||||
newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
|
||||
if newstate ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
|
||||
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
|
||||
statetype = typeof(state)
|
||||
|
||||
# BUG should be node.children[key of newstate] here not newstate. may be a uuid
|
||||
node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||
end
|
||||
|
||||
# add stateValueEstimator
|
||||
|
||||
|
||||
|
||||
|
||||
end
|
||||
end
|
||||
|
||||
@@ -235,6 +228,7 @@ julia> thoughtDict = Dict(
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [PENDING] add other actions
|
||||
- [] add embedding newstate then store in newstate[:embedding]
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -265,7 +259,9 @@ function MCTStransition(a::T1, state::T2,
|
||||
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
|
||||
newstate[:thoughtHistory][latestObservationKey] = response
|
||||
|
||||
return newstate
|
||||
newStatekey = Symbol(GeneralUtils.uuid4snakecase())
|
||||
|
||||
return newStatekey, newstate
|
||||
end
|
||||
|
||||
|
||||
@@ -402,4 +398,24 @@ end
|
||||
|
||||
|
||||
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module mcts
|
||||
Reference in New Issue
Block a user