module mcts

export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot,
    selectChildNode, expand, simulate, makeNewState

using Base.Threads
using GeneralUtils
using ..type

# ------------------------------------------------------------------------------------------

"""
    selectBestNextNode(node::MCTSNode) -> MCTSNode

Select the highest-value child of `node`.

When at least one child has a non-zero `statevalue`, children are ranked by the mean
reward `statevalue / visits`; otherwise (no simulation has been backpropagated into any
child yet) they are ranked by `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree; must have at least one child.

# Returns
- `MCTSNode`: the highest-value child node.
"""
function selectBestNextNode(node::MCTSNode)::MCTSNode
    # Start at -Inf (was -1) so children whose scores are all ≤ -1 can still be
    # selected instead of leaving `bestKey == nothing` and throwing a KeyError.
    bestScore = -Inf
    bestKey = nothing
    stateValueSum = sum(child.statevalue for child in values(node.children))
    if stateValueSum != 0
        for child in values(node.children)
            # NOTE(review): assumes visits > 0 for any child with a non-zero
            # statevalue — holds while statevalue is only updated by backpropagate.
            score = child.statevalue / child.visits
            if score > bestScore
                bestScore = score
                bestKey = child.nodekey
            end
        end
    else
        # No child has a state value yet: fall back to progress + immediate reward.
        for child in values(node.children)
            score = child.progressvalue + child.reward
            if score > bestScore
                bestScore = score
                bestKey = child.nodekey
            end
        end
    end
    return node.children[bestKey]
end

"""
    selectBestTrajectoryNode(node::MCTSNode) -> MCTSNode

Walk down from `node`, repeatedly taking the best child (see
[`selectBestNextNode`](@ref)), until a leaf is reached.

# Arguments
- `node::MCTSNode`: node of a search tree.

# Returns
- `MCTSNode`: the leaf node at the end of the best trajectory.
"""
function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode
    while !isleaf(node)
        node = selectBestNextNode(node)
    end
    return node
end

"""
    backpropagate(node::MCTSNode, simTrajectoryReward; discountRewardCoeff=0.9)

Propagate a simulation reward from `node` up to (but not including) the root.

Each visited node's `visits` counter is incremented and the (discounted) reward is
accumulated into its `statevalue`; selection later computes the mean as
`statevalue / visits`. The reward is multiplied by `discountRewardCoeff` at every
step up the tree, so rewards further in the future count for less.

# Arguments
- `node::MCTSNode`: leaf node of a search tree where the simulation ended.
- `simTrajectoryReward::Number`: total reward from the trajectory simulation.

# Keywords
- `discountRewardCoeff::AbstractFloat=0.9`: per-level discount coefficient.
"""
function backpropagate(node::MCTSNode, simTrajectoryReward::T;
                       discountRewardCoeff::AbstractFloat=0.9) where {T<:Number}
    while !isroot(node)
        node.visits += 1
        # BUGFIX: the previous update, `statevalue += (statevalue*(visits-1)+r)/visits`,
        # re-added the old statevalue on every visit, so values blew up. `statevalue`
        # is a running reward *sum* (selection divides by `visits` for the mean),
        # so simply accumulate the discounted reward.
        node.statevalue += simTrajectoryReward
        simTrajectoryReward *= discountRewardCoeff  # future reward is uncertain
        node = node.parent
    end
    return nothing
end

"""
    isleaf(node::MCTSNode) -> Bool

Return `true` when `node` has no children.
"""
isleaf(node::MCTSNode)::Bool = isempty(node.children)

"""
    isroot(node::MCTSNode) -> Bool

Return `true` when `node` is the root node, i.e. its `nodekey` is `"root"`.
"""
isroot(node::MCTSNode)::Bool = node.nodekey == "root"

"""
    selectChildNode(node::MCTSNode) -> MCTSNode

Select the child of `node` with the highest `progressvalue + reward`.

# Arguments
- `node::MCTSNode`: node of a search tree; must have at least one child.

# Returns
- `MCTSNode`: the highest-value child node.
"""
function selectChildNode(node::MCTSNode)::MCTSNode
    # Start at -Inf (was -1) so children with scores ≤ -1 remain selectable.
    bestScore = -Inf
    bestKey = nothing
    for child in values(node.children)
        score = child.progressvalue + child.reward
        if score > bestScore
            bestScore = score
            bestKey = child.nodekey
        end
    end
    return node.children[bestKey]
end

"""
    expand(node, transition, transitionargs; horizontalSample=3) -> Vector{String}

Expand `node` horizontally by sampling up to `horizontalSample` successor states via
`transition` and attaching each previously-unseen state as a new child of `node`.

# Arguments
- `node::MCTSNode`: the MCTS node to expand.
- `transition::Function`: state-transition function; must return a collection with
  `:newNodeKey`, `:newstate` and `:progressvalue` entries.
- `transitionargs::NamedTuple`: arguments forwarded to `transition`.

# Keywords
- `horizontalSample::Integer=3`: number of successor states to sample.

# Returns
- `Vector{String}`: keys of the children created during this call (insertion order).
"""
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
                horizontalSample::Integer=3)
    listOfNewNodeId = String[]  # typed (was untyped `[]`)
    for _ in 1:horizontalSample
        result = transition(node.state, transitionargs)
        newNodeKey::AbstractString = result[:newNodeKey]
        newstate::AbstractDict = result[:newstate]
        progressvalue::Integer = result[:progressvalue]
        # TODO: even when `newNodeKey` is new, the new state may be semantically close
        # to an existing child's state (déjà vu); such near-duplicates could be merged
        # and used to recall lessons for the decision maker / evaluator.
        if newNodeKey ∉ keys(node.children)
            push!(listOfNewNodeId, newNodeKey)
            node.children[newNodeKey] = MCTSNode(
                newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
                newstate[:isterminal], node, Dict{String, MCTSNode}(),
                Dict{Symbol, Any}(),
            )
        end
    end
    return listOfNewNodeId
end

"""
    simulate(outputchannel, node, transition, transitionargs;
             maxdepth=3, horizontalSample=3)

Simulate interactions between the agent and the environment, starting from `node`.

At each depth the node's reward is added to the trajectory reward; if the node is
terminal the rollout stops, otherwise the node is expanded and the best child (by
`progressvalue + reward`) is followed. The result is also `put!` on `outputchannel`
so the caller can run simulations as concurrent tasks.

# Arguments
- `outputchannel::Channel`: channel that receives the result named tuple.
- `node::MCTSNode`: node that is the simulation starting point.
- `transition::Function`: user function that handles the state transition.
- `transitionargs::NamedTuple`: arguments forwarded to `transition`.

# Keywords
- `maxdepth::Integer=3`: maximum depth MCTS descends vertically.
- `horizontalSample::Integer=3`: number of samples per node expansion.

# Returns
- `(simTrajectoryReward, terminalstate)` named tuple; `terminalstate` is `nothing`
  when no terminal node was reached within `maxdepth`.
"""
function simulate(outputchannel::Channel, node::MCTSNode, transition::Function,
                  transitionargs::NamedTuple; maxdepth::Integer=3,
                  horizontalSample::Integer=3
                  )::NamedTuple{(:simTrajectoryReward, :terminalstate),
                                Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
    simTrajectoryReward = 0.0
    terminalstate = nothing
    for _ in 1:maxdepth
        simTrajectoryReward += node.reward
        if node.isterminal
            terminalstate = node.state
            break
        else
            expand(node, transition, transitionargs; horizontalSample=horizontalSample)
            node = selectChildNode(node)
        end
    end
    result = (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
    put!(outputchannel, result)
    return result
end

"""
    makeNewState(currentstate, thoughtDict, response, select, reward, isterminal)
        -> (newNodeKey, newstate)

Build the successor state for a new MCTS node.

Appends the latest thought/action pair from `thoughtDict` and the observation
`response` to a deep copy of `currentstate`'s `:thoughtHistory`, records `reward`,
`select` and `isterminal` on the copy, and generates a fresh node key.

# Arguments
- `currentstate::AbstractDict`: current agent state containing `:thoughtHistory`.
- `thoughtDict::AbstractDict`: dict holding `thought`/`action` (possibly indexed) keys.
- `response::AbstractString`: environment observation for this step.
- `select::Union{Number, Nothing}`: selection value carried into the new state.
- `reward::Number`: reward of the new state.
- `isterminal::Bool`: whether the new state is terminal.

# Returns
- `(newNodeKey::String, newstate::Dict{Symbol, <:Any})`: a UUID-based node key and
  the new state dict.
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2,
                      select::Union{T3, Nothing}, reward::T3, isterminal::Bool
                      )::Tuple{String, Dict{Symbol, <:Any}} where
        {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
    currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
    # NOTE(review): emptiness is detected via the `:NA` key here but via `-1` below
    # for `thoughtDict` — confirm both sentinels against findHighestIndexKey.
    currentstate_nextIndice = currentstate_latestThoughtKey == :NA ?
        1 : currentstate_latestThoughtIndice + 1
    newThoughtKey = Symbol("thought_$currentstate_nextIndice")
    newActionKey = Symbol("action_$currentstate_nextIndice")

    _, thoughtDict_latestThoughtIndice =
        GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
    thoughtDict_thoughtKey, thoughtDict_actionKey =
        if thoughtDict_latestThoughtIndice == -1
            (:thought, :action)
        else
            (Symbol("thought_$thoughtDict_latestThoughtIndice"),
             Symbol("action_$thoughtDict_latestThoughtIndice"))
        end

    # Append thought, action and observation to the (copied) thought history.
    newstate = deepcopy(currentstate)
    newstate[:thoughtHistory][newThoughtKey] = thoughtDict[thoughtDict_thoughtKey]
    newstate[:thoughtHistory][newActionKey] = thoughtDict[thoughtDict_actionKey]
    newstate[:thoughtHistory][Symbol("observation_$currentstate_nextIndice")] = response
    newstate[:reward] = reward
    newstate[:select] = select
    newstate[:isterminal] = isterminal

    return (GeneralUtils.uuid4snakecase(), newstate)
end

end # module mcts