diff --git a/src/mcts.jl b/src/mcts.jl index 1eb44f5..37efc01 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -407,6 +407,9 @@ function selectChildNode(node::MCTSNode)::MCTSNode # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children potential = childNode.progressvalue + childNode.reward + if childNode.reward > 0 #XXX for testing. remove when done. + println("") + end if potential > highestProgressValue highestProgressValue = potential nodekey = childNode.nodekey @@ -485,7 +488,8 @@ function runMCTS( n::Integer, maxDepth::Integer, maxIterations::Integer, - w::Float64) where {T1<:agent} + w::Float64 + ) where {T1<:agent} root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) @@ -510,7 +514,11 @@ function runMCTS( end best_child_state = argmax([child.statevalue / child.visits for child in values(root.children)]) - error("---> runMCTS") + + + + + return best_child_state end diff --git a/test/test_1.jl b/test/test_1.jl index 887703e..62ba30e 100644 --- a/test/test_1.jl +++ b/test/test_1.jl @@ -162,10 +162,10 @@ result = GeneralUtils.sendMqttMsg(outgoingMsg) outgoingMsg = Dict( :msgMeta=> msgMeta, :payload=> Dict( - :text=> "You didn't tell me wine name.", + :text=> "Yes.", :select=> nothing, - :reward=> -1, - :isterminal=> true, + :reward=> 0, + :isterminal=> false, ) ) result = GeneralUtils.sendMqttMsg(outgoingMsg)