From 2eff443f70e28968c00fce585fe9fc7831c69b3e Mon Sep 17 00:00:00 2001 From: tonaerospace Date: Fri, 14 Mar 2025 21:57:59 +0700 Subject: [PATCH] update --- src/mcts.jl | 55 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/src/mcts.jl b/src/mcts.jl index 0df0c1b..81839ab 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -211,7 +211,6 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; end end - function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) result = transition(node.state, transitionargs) newNodeKey::AbstractString = result[:newNodeKey] @@ -244,17 +243,24 @@ end Arguments for everything the user will use within transition(). - `maxSimulationDepth::Integer` maximum depth level MCTS goes vertically during simulation. - - horizontalSample::Integer + - `horizontalSample::Integer` Total number to sample from the current node (i.e. expand new node horizontally) + +# Keyword Arguments + - `multithread::Bool` + Whether to run expansion in parallel using multiple threads. Defaults to false. # Return - - `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}` + - `simTrajectoryReward::Number` + Cumulative reward collected along the simulation trajectory + - `terminalstate::Union{Dict{Symbol, Any}, Nothing}` + Final state if terminal state reached, nothing otherwise # Signature """ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; - maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false) -# )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}} + maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false +)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}} simTrajectoryReward = 0.0 terminalstate = nothing @@ -275,38 +281,54 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) end - -""" +""" Make new state # Arguments - + - `currentstate::T1` + Current state dictionary containing thought history and metadata + - `thoughtDict::T4` + Dictionary containing new thought and action + - `response::T2` + Response string from the environment + - `select::Union{T3, Nothing}` + Selection value or nothing + - `reward::T3` + Reward value for this state + - `isterminal::Bool` + Whether this state is terminal + # Return + - `Tuple{String, Dict{Symbol, <:Any}}` + A tuple containing: + - A unique node key string + - A new state dictionary with updated thought history and metadata # Example ```jldoctest julia> ``` -# TODO - - [] update docstring - - [x] implement the function - # Signature """ function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, reward::T3, isterminal::Bool )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} + # Find the latest thought key and index from current state's thought history currentstate_latestThoughtKey, currentstate_latestThoughtIndice = GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") + # Calculate next index for new thought/action currentstate_nextIndice = currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 + # Create new keys for thought and action based on next index currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") latestActionKey = Symbol("action_$currentstate_nextIndice") + # Find the latest thought index from input thought dictionary _, thoughtDict_latestThoughtIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "thought") + # Determine thought and action keys from thought dictionary thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = if thoughtDict_latestThoughtIndice == -1 (:thought, :action) @@ -317,17 +339,22 @@ function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::U ) end - # add Thought, action, observation to thoughtHistory + # Create new state by deep copying current state newstate = deepcopy(currentstate) + # Update thought history with new thought newstate[:thoughtHistory][currentstate_latestThoughtKey] = thoughtDict[thoughtDict_latestThoughtKey] + # Update thought history with new action newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] + # Create and add new observation to thought history newObservationKey = Symbol("observation_$(currentstate_nextIndice)") newstate[:thoughtHistory][newObservationKey] = response + # Update state metadata newstate[:reward] = reward newstate[:select] = select newstate[:isterminal] = isterminal + # Generate unique ID for new node newNodeKey = GeneralUtils.uuid4snakecase() return (newNodeKey, newstate) @@ -394,8 +421,6 @@ end - -