update
This commit is contained in:
@@ -331,27 +331,35 @@ end
|
||||
a game state
|
||||
|
||||
# Return
|
||||
- `(isterminal, reward)::Tuple{Bool, Number}`
|
||||
- `(isterminalstate, reward)::Tuple{Bool, <:Number}`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [x] update docstring
|
||||
- [TESTING] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict}
|
||||
function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
|
||||
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation")
|
||||
latestObservation = state[:thoughtHistory][latestObservationKey]
|
||||
|
||||
if latestObservation !== nothing
|
||||
|
||||
# terminal condition is when the user select wine by putting <<winename>> in latest observation
|
||||
if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
||||
return true, 1
|
||||
isterminalstate = true
|
||||
reward = 1
|
||||
else
|
||||
isterminalstate = false
|
||||
reward = 0
|
||||
end
|
||||
else
|
||||
isterminalstate = false
|
||||
reward = 0
|
||||
end
|
||||
|
||||
return (isterminalstate, reward)
|
||||
end
|
||||
|
||||
|
||||
|
||||
22
src/mcts.jl
22
src/mcts.jl
@@ -122,20 +122,21 @@ julia>
|
||||
# Signature
|
||||
"""
|
||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
|
||||
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
|
||||
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
|
||||
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
|
||||
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
|
||||
isterminal)
|
||||
|
||||
# add progressValueEstimator
|
||||
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
||||
reward, isterminal, node, Dict{String, MCTSNode}())
|
||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -158,11 +159,12 @@ julia>
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(a, node::MCTSNode, decisionMaker, progressValueEstimator, max_depth::Int; n=3)
|
||||
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||
isterminal::Function, max_depth::Int; n=3)
|
||||
|
||||
for _ in 1:max_depth
|
||||
node = selectChildNode(node)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
# if isterminal (use for loop over node to look for childNode.reward != 0)
|
||||
|
||||
@@ -243,8 +245,9 @@ julia> thoughtDict = Dict(
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function MCTStransition(a::T1, state::T2,
|
||||
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
|
||||
actionname = thoughtDict[:Action][:name]
|
||||
actioninput = thoughtDict[:Action][:input]
|
||||
|
||||
@@ -401,11 +404,12 @@ function runMCTS(
|
||||
node = UCTselect(node, w)
|
||||
end
|
||||
|
||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||
|
||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||
startsim_node = node
|
||||
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator, maxDepth, n=n)
|
||||
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
|
||||
isterminal, maxDepth, n=n)
|
||||
backpropagate(leaf_node, reward)
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user