This commit is contained in:
2024-05-01 13:15:48 +07:00
parent 513f159be1
commit fba99ab695
3 changed files with 54 additions and 34 deletions

View File

@@ -110,33 +110,26 @@ julia>
# Signature
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T1<:agent, T2<:AbstractDict}
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
stateValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker
for sample in 1:n
thoughtJstr = decisionMaker(a, state)
thoughtDict = copy(JSON3.read(thoughtJstr))
""" Example of thoughtDict
Dict{Symbol, Any} with 3 entries:
:Thought_1 => "The customer wants to buy a bottle of wine. This is a good start!"
:Action_1 => Dict{Symbol, Any}(
:action=>"Chatbox",
:input=>"What occasion are you buying the wine for?"
)
:Observation_1 => ""
"""
thoughtDict = decisionMaker(a, state)
@show state
@show thoughtDict
newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
if newstate ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
statetype = typeof(state)
# BUG should be node.children[key of newstate] here not newstate. may be a uuid
node.children[newstate] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
end
# add stateValueEstimator
end
end
@@ -235,6 +228,7 @@ julia> thoughtDict = Dict(
# TODO
- [] update docstring
- [PENDING] add other actions
- [] compute an embedding of newstate, then store it in newstate[:embedding]
# Signature
"""
@@ -265,7 +259,9 @@ function MCTStransition(a::T1, state::T2,
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
return newstate
newStatekey = Symbol(GeneralUtils.uuid4snakecase())
return newStatekey, newstate
end
@@ -402,4 +398,24 @@ end
end
end # module mcts