This commit is contained in:
2025-03-18 21:23:09 +07:00
parent 693cbfd82d
commit ee5f8a8a52
2 changed files with 92 additions and 84 deletions

View File

@@ -9,7 +9,6 @@ using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
# Arguments # Arguments
@@ -63,14 +62,15 @@ function runMCTS(
explorationweight::Number=1.0, explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing, earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
multithread=false multithread=false,
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any} )::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
Tuple{MCTSNode,T,T,Vector{Any}}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}()) Dict{Symbol,Any}())
# [WORKING] storage for holding all high reward terminal nodes # storage for holding all high reward terminal nodes
highStateValueNode = Channel{Any}(100) highValueState = Channel{Any}(100)
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
@@ -99,7 +99,7 @@ function runMCTS(
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode, saveSimulatedNode=saveSimulatedNode,
multithread=multithread, multithread=multithread,
highStateValueNode=highStateValueNode, highValueState=highValueState,
) )
end end
else else
@@ -109,7 +109,7 @@ function runMCTS(
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode, saveSimulatedNode=saveSimulatedNode,
multithread=multithread, multithread=multithread,
highStateValueNode=highStateValueNode) highValueState=highValueState)
end end
end end
end end
@@ -120,15 +120,23 @@ function runMCTS(
end end
end end
# select the best next state and the best final state # select the best next state and the best terminal state along the best trajectory
bestNextState = selectBestNextNode(root) bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root) bestTerminalState = selectBestTrajectoryNode(root)
#[WORKING] compare all high value answer then select the best one # drain every high-value state from the highValueState channel and collect them in a list
highValueStateList = Any[]
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) while !isempty(highValueState)
push!(highValueStateList, take!(highValueState))
end end
return (root=root,
bestNextState=bestNextState.state,
bestTerminalState=bestTerminalState.state,
highValueStateList=highValueStateList)
end
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
# Arguments # Arguments
@@ -156,18 +164,18 @@ function simulateThenBackpropagate(node::MCTSNode, transition::Function, transit
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
multithread=false, multithread=false,
highStateValueNode=Union{Nothing,Any}=nothing) highValueState=Union{Nothing,Any}=nothing)
simTrajectoryReward, terminalstate = simTrajectoryReward, terminalstate =
simulate(node, transition, transitionargs; simulate(node, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase, horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread) multithread=multithread)
# if a node has state value >= 8, store it in highStateValueNode # if the simulated terminal state has reward >= 8, store a deep copy of it in highValueState
if highStateValueNode !== nothing && if highValueState !== nothing &&
terminalstate !== nothing && terminalstate !== nothing &&
terminalstate[:reward] >= 8 terminalstate[:reward] >= 8
put!(highStateValueNode, deepcopy(terminalstate)) put!(highValueState, deepcopy(terminalstate))
end end
backpropagate(node, simTrajectoryReward) backpropagate(node, simTrajectoryReward)

View File

@@ -1,7 +1,7 @@
module mcts module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode, export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState expand, simulate
using Base.Threads using Base.Threads
using GeneralUtils using GeneralUtils
@@ -302,84 +302,84 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
terminalstate=terminalstate) terminalstate=terminalstate)
end end
""" Make new state # """ Make new state
# Arguments # # Arguments
- `currentstate::T1` # - `currentstate::T1`
Current state dictionary containing thought history and metadata # Current state dictionary containing thought history and metadata
- `thoughtDict::T4` # - `thoughtDict::T4`
Dictionary containing new thought and action # Dictionary containing new thought and action
- `response::T2` # - `response::T2`
Response string from the environment # Response string from the environment
- `select::Union{T3, Nothing}` # - `select::Union{T3, Nothing}`
Selection value or nothing # Selection value or nothing
- `reward::T3` # - `reward::T3`
Reward value for this state # Reward value for this state
- `isterminal::Bool` # - `isterminal::Bool`
Whether this state is terminal # Whether this state is terminal
# Return # # Return
- `Tuple{String, Dict{Symbol, <:Any}}` # - `Tuple{String, Dict{Symbol, <:Any}}`
A tuple containing: # A tuple containing:
- A unique node key string # - A unique node key string
- A new state dictionary with updated thought history and metadata # - A new state dictionary with updated thought history and metadata
# Example # # Example
```jldoctest # ```jldoctest
julia> # julia>
``` # ```
# Signature # # Signature
""" # """
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, # function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool # reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} # )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
# Find the latest thought key and index from current state's thought history # # Find the latest thought key and index from current state's thought history
currentstate_latestThoughtKey, currentstate_latestThoughtIndice = # currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") # GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
# Calculate next index for new thought/action # # Calculate next index for new thought/action
currentstate_nextIndice = # currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 # currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
# Create new keys for thought and action based on next index # # Create new keys for thought and action based on next index
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") # currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice") # latestActionKey = Symbol("action_$currentstate_nextIndice")
# Find the latest thought index from input thought dictionary # # Find the latest thought index from input thought dictionary
_, thoughtDict_latestThoughtIndice = # _, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought") # GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
# Determine thought and action keys from thought dictionary # # Determine thought and action keys from thought dictionary
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = # thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1 # if thoughtDict_latestThoughtIndice == -1
(:thought, :action) # (:thought, :action)
else # else
( # (
Symbol("thought_$thoughtDict_latestThoughtIndice"), # Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"), # Symbol("action_$thoughtDict_latestThoughtIndice"),
) # )
end # end
# Create new state by deep copying current state # # Create new state by deep copying current state
newstate = deepcopy(currentstate) # newstate = deepcopy(currentstate)
# Update thought history with new thought # # Update thought history with new thought
newstate[:thoughtHistory][currentstate_latestThoughtKey] = # newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey] # thoughtDict[thoughtDict_latestThoughtKey]
# Update thought history with new action # # Update thought history with new action
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] # newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
# Create and add new observation to thought history # # Create and add new observation to thought history
newObservationKey = Symbol("observation_$(currentstate_nextIndice)") # newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response # newstate[:thoughtHistory][newObservationKey] = response
# Update state metadata # # Update state metadata
newstate[:reward] = reward # newstate[:reward] = reward
newstate[:select] = select # newstate[:select] = select
newstate[:isterminal] = isterminal # newstate[:isterminal] = isterminal
# Generate unique ID for new node # # Generate unique ID for new node
newNodeKey = GeneralUtils.uuid4snakecase() # newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate) # return (newNodeKey, newstate)
end # end