diff --git a/src/interface.jl b/src/interface.jl
index 89dadac..e47db12 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -331,27 +331,35 @@ end
     a game state
 
 # Return
-  - `(isterminal, reward)::Tuple{Bool, Number}`
+  - `(isterminalstate, reward)::Tuple{Bool, <:Number}`
 
 # Example
 ```jldoctest
 julia>
 ```
 
-# TODO
-  - [x] update docstring
-  - [TESTING] implement the function
-
 # Signature
 """
-function isterminal(state::T)::Tuple{Bool, Number} where {T<:AbstractDict}
+function isterminal(state::T)::Tuple{Bool, <:Number} where {T<:AbstractDict}
     latestObservationKey, _ = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Observation")
     latestObservation = state[:thoughtHistory][latestObservationKey]
 
-    # terminal condition is when the user select wine by putting <> in latest observation
-    if occursin("<<", latestObservation) && occursin(">>", latestObservation)
-        return true, 1
+    if latestObservation !== nothing
+
+        # terminal condition is when the user select wine by putting <> in latest observation
+        if occursin("<<", latestObservation) && occursin(">>", latestObservation)
+            isterminalstate = true
+            reward = 1
+        else
+            isterminalstate = false
+            reward = 0
+        end
+    else
+        isterminalstate = false
+        reward = 0
     end
+
+    return (isterminalstate, reward)
 end
 
 
diff --git a/src/mcts.jl b/src/mcts.jl
index d840e33..683ef00 100644
--- a/src/mcts.jl
+++ b/src/mcts.jl
@@ -122,20 +122,21 @@ julia>
 
 # Signature
 """
 function expand(a::T1, node::MCTSNode, decisionMaker::Function,
-        progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
+        progressValueEstimator::Function, isterminal::Function; n::Integer=3) where {T1<:agent}
 
     # sampling action from decisionMaker
     for sample in 1:n
         thoughtDict = decisionMaker(a, node.state)
-        newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
+        newNodeKey, newstate, isterminalstate, reward = MCTStransition(a, node.state, thoughtDict,
+                                                                       isterminal)
 
         # add progressValueEstimator
         progressRationale,
        progressValue = progressValueEstimator(a, newstate)
         if newNodeKey ∉ keys(node.children)
             node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
-                                                 reward, isterminal, node, Dict{String, MCTSNode}())
+                                                 reward, isterminalstate, node, Dict{String, MCTSNode}())
         end
     end
 end
@@ -158,11 +159,12 @@ julia>
 
 # Signature
 """
-function simulate(a, node::MCTSNode, decisionMaker, progressValueEstimator, max_depth::Int; n=3)
+function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
+        isterminal::Function, max_depth::Int; n=3)
 
     for _ in 1:max_depth
 
         node = selectChildNode(node)
-        expand(a, node, decisionMaker, progressValueEstimator, n=n)
+        expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
 
         # if isterminal (use for loop over node to look for childNode.reward != 0)
@@ -243,8 +245,9 @@ julia> thoughtDict = Dict(
 
 # Signature
 """
-function MCTStransition(a::T1, state::T2,
-        thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
+function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
+        )::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
+
     actionname = thoughtDict[:Action][:name]
     actioninput = thoughtDict[:Action][:input]
 
@@ -401,11 +404,12 @@ function runMCTS(
             node = UCTselect(node, w)
         end
 
-        expand(a, node, decisionMaker, progressValueEstimator, n=n)
+        expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
 
         # from paper, just start simulation at this node. Not the node that newly expanded
         startsim_node = node
-        reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator, maxDepth, n=n)
+        reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
+            isterminal, maxDepth, n=n)
 
         backpropagate(leaf_node, reward)
     end