update
This commit is contained in:
@@ -331,27 +331,35 @@ end
|
|||||||
a game state
|
a game state
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- `(isterminal, reward)::Tuple{Bool, Number}`
|
- `(isterminalstate, reward)::Tuple{Bool, <:Number}`
|
||||||
|
|
||||||
# Example
|
# Example
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia>
|
julia>
|
||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
|
||||||
- [x] update docstring
|
|
||||||
- [TESTING] implement the function
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict}
|
function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
|
||||||
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation")
|
latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation")
|
||||||
latestObservation = state[:thoughtHistory][latestObservationKey]
|
latestObservation = state[:thoughtHistory][latestObservationKey]
|
||||||
|
|
||||||
# terminal condition is when the user select wine by putting <<winename>> in latest observation
|
if latestObservation !== nothing
|
||||||
if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
|
||||||
return true, 1
|
# terminal condition is when the user select wine by putting <<winename>> in latest observation
|
||||||
|
if occursin("<<", latestObservation) && occursin(">>", latestObservation)
|
||||||
|
isterminalstate = true
|
||||||
|
reward = 1
|
||||||
|
else
|
||||||
|
isterminalstate = false
|
||||||
|
reward = 0
|
||||||
|
end
|
||||||
|
else
|
||||||
|
isterminalstate = false
|
||||||
|
reward = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
|
return (isterminalstate, reward)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
22
src/mcts.jl
22
src/mcts.jl
@@ -122,20 +122,21 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
|
progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
|
||||||
|
|
||||||
# sampling action from decisionMaker
|
# sampling action from decisionMaker
|
||||||
for sample in 1:n
|
for sample in 1:n
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
|
|
||||||
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
|
newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
|
||||||
|
isterminal)
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
||||||
|
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
||||||
reward, isterminal, node, Dict{String, MCTSNode}())
|
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -158,11 +159,12 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(a, node::MCTSNode, decisionMaker, progressValueEstimator, max_depth::Int; n=3)
|
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||||
|
isterminal::Function, max_depth::Int; n=3)
|
||||||
|
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
|
|
||||||
# if isterminal (use for loop over node to look for childNode.reward != 0)
|
# if isterminal (use for loop over node to look for childNode.reward != 0)
|
||||||
|
|
||||||
@@ -243,8 +245,9 @@ julia> thoughtDict = Dict(
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function MCTStransition(a::T1, state::T2,
|
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||||
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||||
|
|
||||||
actionname = thoughtDict[:Action][:name]
|
actionname = thoughtDict[:Action][:name]
|
||||||
actioninput = thoughtDict[:Action][:input]
|
actioninput = thoughtDict[:Action][:input]
|
||||||
|
|
||||||
@@ -401,11 +404,12 @@ function runMCTS(
|
|||||||
node = UCTselect(node, w)
|
node = UCTselect(node, w)
|
||||||
end
|
end
|
||||||
|
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
|
|
||||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||||
startsim_node = node
|
startsim_node = node
|
||||||
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator, maxDepth, n=n)
|
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
|
||||||
|
isterminal, maxDepth, n=n)
|
||||||
backpropagate(leaf_node, reward)
|
backpropagate(leaf_node, reward)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user