From ee5f8a8a52a3080ab4443e8f63d46b2702363567 Mon Sep 17 00:00:00 2001 From: tonaerospace Date: Tue, 18 Mar 2025 21:23:09 +0700 Subject: [PATCH] update --- src/interface.jl | 38 +++++++------ src/mcts.jl | 138 +++++++++++++++++++++++------------------------ 2 files changed, 92 insertions(+), 84 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 8251df4..db8970a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -9,7 +9,6 @@ using ..type, ..mcts, ..util # ---------------------------------------------- 100 --------------------------------------------- # - """ Search the best action to take for a given state and task # Arguments @@ -63,14 +62,15 @@ function runMCTS( explorationweight::Number=1.0, earlystop::Union{Function,Nothing}=nothing, saveSimulatedNode::Bool=false, - multithread=false - )::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any} + multithread=false, + )::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList), + Tuple{MCTSNode,T,T,Vector{Any}}} where {T<:Any} root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), Dict{Symbol,Any}()) - # [WORKING] storage for holding all high reward terminal nodes - highStateValueNode = Channel{Any}(100) + # storage for holding all high reward terminal nodes + highValueState = Channel{Any}(100) for nth in 1:maxiterations node = root @@ -99,7 +99,7 @@ function runMCTS( horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, saveSimulatedNode=saveSimulatedNode, multithread=multithread, - highStateValueNode=highStateValueNode, + highValueState=highValueState, ) end else @@ -109,7 +109,7 @@ function runMCTS( horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, saveSimulatedNode=saveSimulatedNode, multithread=multithread, - highStateValueNode=highStateValueNode) + highValueState=highValueState) end end end @@ -120,15 +120,23 @@ function runMCTS( end end - # select the best next state and the 
best final state + # select the best next state and the best terminal state along the best trajectory bestNextState = selectBestNextNode(root) - besttrajectory = selectBestTrajectoryNode(root) + bestTerminalState = selectBestTrajectoryNode(root) - #[WORKING] compare all high value answer then select the best one + # take all high value state from highValueState channel and put it in a list + highValueStateList = Any[] + while isready(highValueState) + push!(highValueStateList, take!(highValueState)) + end - return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) + return (root=root, + bestNextState=bestNextState.state, + bestTerminalState=bestTerminalState.state, + highValueStateList=highValueStateList) end + """ Search the best action to take for a given state and task # Arguments @@ -156,18 +164,18 @@ function simulateThenBackpropagate(node::MCTSNode, transition::Function, transit maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, saveSimulatedNode::Bool=false, multithread=false, - highStateValueNode=Union{Nothing,Any}=nothing) + highValueState::Union{Nothing,Any}=nothing) simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs; maxSimulationDepth=maxSimulationDepth, horizontalSample=horizontalSampleSimulationPhase, multithread=multithread) - # if a node has state value >= 8, store it in highStateValueNode - if highStateValueNode !== nothing && + # if a node has state value >= 8, store it in highValueState + if highValueState !== nothing && terminalstate !== nothing && terminalstate[:reward] >= 8 - put!(highStateValueNode, deepcopy(terminalstate)) + put!(highValueState, deepcopy(terminalstate)) end backpropagate(node, simTrajectoryReward) diff --git a/src/mcts.jl b/src/mcts.jl index 59b50e5..ce7171f 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -1,7 +1,7 @@ module mcts export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode, - expand, 
simulate, makeNewState + expand, simulate using Base.Threads using GeneralUtils @@ -302,84 +302,84 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup terminalstate=terminalstate) end -""" Make new state +# """ Make new state -# Arguments - - `currentstate::T1` - Current state dictionary containing thought history and metadata - - `thoughtDict::T4` - Dictionary containing new thought and action - - `response::T2` - Response string from the environment - - `select::Union{T3, Nothing}` - Selection value or nothing - - `reward::T3` - Reward value for this state - - `isterminal::Bool` - Whether this state is terminal +# # Arguments +# - `currentstate::T1` +# Current state dictionary containing thought history and metadata +# - `thoughtDict::T4` +# Dictionary containing new thought and action +# - `response::T2` +# Response string from the environment +# - `select::Union{T3, Nothing}` +# Selection value or nothing +# - `reward::T3` +# Reward value for this state +# - `isterminal::Bool` +# Whether this state is terminal -# Return - - `Tuple{String, Dict{Symbol, <:Any}}` - A tuple containing: - - A unique node key string - - A new state dictionary with updated thought history and metadata +# # Return +# - `Tuple{String, Dict{Symbol, <:Any}}` +# A tuple containing: +# - A unique node key string +# - A new state dictionary with updated thought history and metadata -# Example -```jldoctest -julia> -``` +# # Example +# ```jldoctest +# julia> +# ``` -# Signature -""" -function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, - reward::T3, isterminal::Bool - )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} +# # Signature +# """ +# function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, +# reward::T3, isterminal::Bool +# )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, 
T3<:Number, T4<:AbstractDict} - # Find the latest thought key and index from current state's thought history - currentstate_latestThoughtKey, currentstate_latestThoughtIndice = - GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") - # Calculate next index for new thought/action - currentstate_nextIndice = - currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 - # Create new keys for thought and action based on next index - currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") - latestActionKey = Symbol("action_$currentstate_nextIndice") +# # Find the latest thought key and index from current state's thought history +# currentstate_latestThoughtKey, currentstate_latestThoughtIndice = +# GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") +# # Calculate next index for new thought/action +# currentstate_nextIndice = +# currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 +# # Create new keys for thought and action based on next index +# currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") +# latestActionKey = Symbol("action_$currentstate_nextIndice") - # Find the latest thought index from input thought dictionary - _, thoughtDict_latestThoughtIndice = - GeneralUtils.findHighestIndexKey(thoughtDict, "thought") +# # Find the latest thought index from input thought dictionary +# _, thoughtDict_latestThoughtIndice = +# GeneralUtils.findHighestIndexKey(thoughtDict, "thought") - # Determine thought and action keys from thought dictionary - thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = - if thoughtDict_latestThoughtIndice == -1 - (:thought, :action) - else - ( - Symbol("thought_$thoughtDict_latestThoughtIndice"), - Symbol("action_$thoughtDict_latestThoughtIndice"), - ) - end +# # Determine thought and action keys from thought dictionary +# thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = +# if 
thoughtDict_latestThoughtIndice == -1 +# (:thought, :action) +# else +# ( +# Symbol("thought_$thoughtDict_latestThoughtIndice"), +# Symbol("action_$thoughtDict_latestThoughtIndice"), +# ) +# end - # Create new state by deep copying current state - newstate = deepcopy(currentstate) - # Update thought history with new thought - newstate[:thoughtHistory][currentstate_latestThoughtKey] = - thoughtDict[thoughtDict_latestThoughtKey] - # Update thought history with new action - newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] - # Create and add new observation to thought history - newObservationKey = Symbol("observation_$(currentstate_nextIndice)") - newstate[:thoughtHistory][newObservationKey] = response - # Update state metadata - newstate[:reward] = reward - newstate[:select] = select - newstate[:isterminal] = isterminal +# # Create new state by deep copying current state +# newstate = deepcopy(currentstate) +# # Update thought history with new thought +# newstate[:thoughtHistory][currentstate_latestThoughtKey] = +# thoughtDict[thoughtDict_latestThoughtKey] +# # Update thought history with new action +# newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] +# # Create and add new observation to thought history +# newObservationKey = Symbol("observation_$(currentstate_nextIndice)") +# newstate[:thoughtHistory][newObservationKey] = response +# # Update state metadata +# newstate[:reward] = reward +# newstate[:select] = select +# newstate[:isterminal] = isterminal - # Generate unique ID for new node - newNodeKey = GeneralUtils.uuid4snakecase() +# # Generate unique ID for new node +# newNodeKey = GeneralUtils.uuid4snakecase() - return (newNodeKey, newstate) -end +# return (newNodeKey, newstate) +# end