This commit is contained in:
2025-03-18 21:23:09 +07:00
parent 693cbfd82d
commit ee5f8a8a52
2 changed files with 92 additions and 84 deletions

View File

@@ -9,7 +9,6 @@ using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
@@ -63,14 +62,15 @@ function runMCTS(
explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false,
multithread=false
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
multithread=false,
)::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
Tuple{MCTSNode,T,T,Vector{Any}}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}())
# [WORKING] storage for holding all high reward terminal nodes
highStateValueNode = Channel{Any}(100)
# storage for holding all high reward terminal nodes
highValueState = Channel{Any}(100)
for nth in 1:maxiterations
node = root
@@ -99,7 +99,7 @@ function runMCTS(
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode,
multithread=multithread,
highStateValueNode=highStateValueNode,
highValueState=highValueState,
)
end
else
@@ -109,7 +109,7 @@ function runMCTS(
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode,
multithread=multithread,
highStateValueNode=highStateValueNode)
highValueState=highValueState)
end
end
end
@@ -120,15 +120,23 @@ function runMCTS(
end
end
# select the best next state and the best final state
# select the best next state and the best terminal state along the best trajectory
bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root)
bestTerminalState = selectBestTrajectoryNode(root)
#[WORKING] compare all high value answer then select the best one
# take all high value state from highValueState channel and put it in a list
highValueStateList = Any[]
while !isempty(highValueState)
push!(highValueStateList, take!(highValueState))
end
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
return (root=root,
bestNextState=bestNextState.state,
bestTerminalState=bestTerminalState.state,
highValueStateList=highValueStateList)
end
""" Search the best action to take for a given state and task
# Arguments
@@ -156,18 +164,18 @@ function simulateThenBackpropagate(node::MCTSNode, transition::Function, transit
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false,
multithread=false,
highStateValueNode=Union{Nothing,Any}=nothing)
highValueState=Union{Nothing,Any}=nothing)
simTrajectoryReward, terminalstate =
simulate(node, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread)
# if a node has state value >= 8, store it in highStateValueNode
if highStateValueNode !== nothing &&
# if a node has state value >= 8, store it in highValueState
if highValueState !== nothing &&
terminalstate !== nothing &&
terminalstate[:reward] >= 8
put!(highStateValueNode, deepcopy(terminalstate))
put!(highValueState, deepcopy(terminalstate))
end
backpropagate(node, simTrajectoryReward)

View File

@@ -1,7 +1,7 @@
module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState
expand, simulate
using Base.Threads
using GeneralUtils
@@ -302,84 +302,84 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
terminalstate=terminalstate)
end
""" Make new state
# """ Make new state
# Arguments
- `currentstate::T1`
Current state dictionary containing thought history and metadata
- `thoughtDict::T4`
Dictionary containing new thought and action
- `response::T2`
Response string from the environment
- `select::Union{T3, Nothing}`
Selection value or nothing
- `reward::T3`
Reward value for this state
- `isterminal::Bool`
Whether this state is terminal
# # Arguments
# - `currentstate::T1`
# Current state dictionary containing thought history and metadata
# - `thoughtDict::T4`
# Dictionary containing new thought and action
# - `response::T2`
# Response string from the environment
# - `select::Union{T3, Nothing}`
# Selection value or nothing
# - `reward::T3`
# Reward value for this state
# - `isterminal::Bool`
# Whether this state is terminal
# Return
- `Tuple{String, Dict{Symbol, <:Any}}`
A tuple containing:
- A unique node key string
- A new state dictionary with updated thought history and metadata
# # Return
# - `Tuple{String, Dict{Symbol, <:Any}}`
# A tuple containing:
# - A unique node key string
# - A new state dictionary with updated thought history and metadata
# Example
```jldoctest
julia>
```
# # Example
# ```jldoctest
# julia>
# ```
# Signature
"""
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
# # Signature
# """
# function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
# reward::T3, isterminal::Bool
# )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
# Find the latest thought key and index from current state's thought history
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
# Calculate next index for new thought/action
currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
# Create new keys for thought and action based on next index
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice")
# # Find the latest thought key and index from current state's thought history
# currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
# GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
# # Calculate next index for new thought/action
# currentstate_nextIndice =
# currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
# # Create new keys for thought and action based on next index
# currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
# latestActionKey = Symbol("action_$currentstate_nextIndice")
# Find the latest thought index from input thought dictionary
_, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
# # Find the latest thought index from input thought dictionary
# _, thoughtDict_latestThoughtIndice =
# GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
# Determine thought and action keys from thought dictionary
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1
(:thought, :action)
else
(
Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"),
)
end
# # Determine thought and action keys from thought dictionary
# thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
# if thoughtDict_latestThoughtIndice == -1
# (:thought, :action)
# else
# (
# Symbol("thought_$thoughtDict_latestThoughtIndice"),
# Symbol("action_$thoughtDict_latestThoughtIndice"),
# )
# end
# Create new state by deep copying current state
newstate = deepcopy(currentstate)
# Update thought history with new thought
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey]
# Update thought history with new action
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
# Create and add new observation to thought history
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response
# Update state metadata
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
# # Create new state by deep copying current state
# newstate = deepcopy(currentstate)
# # Update thought history with new thought
# newstate[:thoughtHistory][currentstate_latestThoughtKey] =
# thoughtDict[thoughtDict_latestThoughtKey]
# # Update thought history with new action
# newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
# # Create and add new observation to thought history
# newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
# newstate[:thoughtHistory][newObservationKey] = response
# # Update state metadata
# newstate[:reward] = reward
# newstate[:select] = select
# newstate[:isterminal] = isterminal
# Generate unique ID for new node
newNodeKey = GeneralUtils.uuid4snakecase()
# # Generate unique ID for new node
# newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate)
end
# return (newNodeKey, newstate)
# end