This commit is contained in:
2025-03-14 21:57:59 +07:00
parent 7e160f2031
commit 2eff443f70

View File

@@ -211,7 +211,6 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
end end
end end
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple) function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
result = transition(node.state, transitionargs) result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
@@ -244,17 +243,24 @@ end
Arguments for everything the user will use within transition(). Arguments for everything the user will use within transition().
- `maxSimulationDepth::Integer` - `maxSimulationDepth::Integer`
maximum depth level MCTS goes vertically during simulation. maximum depth level MCTS goes vertically during simulation.
- horizontalSample::Integer - `horizontalSample::Integer`
Total number to sample from the current node (i.e. expand new node horizontally) Total number to sample from the current node (i.e. expand new node horizontally)
# Keyword Arguments
- `multithread::Bool`
Whether to run expansion in parallel using multiple threads. Defaults to false.
# Return # Return
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}` - `simTrajectoryReward::Number`
Cumulative reward collected along the simulation trajectory
- `terminalstate::Union{Dict{Symbol, Any}, Nothing}`
Final state if terminal state reached, nothing otherwise
# Signature # Signature
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false) maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
# )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}} )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
@@ -275,38 +281,54 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
end end
""" Make new state
"""
# Arguments # Arguments
- `currentstate::T1`
Current state dictionary containing thought history and metadata
- `thoughtDict::T4`
Dictionary containing new thought and action
- `response::T2`
Response string from the environment
- `select::Union{T3, Nothing}`
Selection value or nothing
- `reward::T3`
Reward value for this state
- `isterminal::Bool`
Whether this state is terminal
# Return # Return
- `Tuple{String, Dict{Symbol, <:Any}}`
A tuple containing:
- A unique node key string
- A new state dictionary with updated thought history and metadata
# Example # Example
```jldoctest ```jldoctest
julia> julia>
``` ```
# TODO
- [] update docstring
- [x] implement the function
# Signature # Signature
""" """
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
# Find the latest thought key and index from current state's thought history
currentstate_latestThoughtKey, currentstate_latestThoughtIndice = currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
# Calculate next index for new thought/action
currentstate_nextIndice = currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
# Create new keys for thought and action based on next index
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice") latestActionKey = Symbol("action_$currentstate_nextIndice")
# Find the latest thought index from input thought dictionary
_, thoughtDict_latestThoughtIndice = _, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought") GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
# Determine thought and action keys from thought dictionary
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1 if thoughtDict_latestThoughtIndice == -1
(:thought, :action) (:thought, :action)
@@ -317,17 +339,22 @@ function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::U
) )
end end
# add Thought, action, observation to thoughtHistory # Create new state by deep copying current state
newstate = deepcopy(currentstate) newstate = deepcopy(currentstate)
# Update thought history with new thought
newstate[:thoughtHistory][currentstate_latestThoughtKey] = newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey] thoughtDict[thoughtDict_latestThoughtKey]
# Update thought history with new action
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
# Create and add new observation to thought history
newObservationKey = Symbol("observation_$(currentstate_nextIndice)") newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response newstate[:thoughtHistory][newObservationKey] = response
# Update state metadata
newstate[:reward] = reward newstate[:reward] = reward
newstate[:select] = select newstate[:select] = select
newstate[:isterminal] = isterminal newstate[:isterminal] = isterminal
# Generate unique ID for new node
newNodeKey = GeneralUtils.uuid4snakecase() newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate) return (newNodeKey, newstate)
@@ -394,8 +421,6 @@ end