update
This commit is contained in:
@@ -9,7 +9,6 @@ using ..type, ..mcts, ..util
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||
|
||||
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
@@ -63,14 +62,15 @@ function runMCTS(
|
||||
explorationweight::Number=1.0,
|
||||
earlystop::Union{Function,Nothing}=nothing,
|
||||
saveSimulatedNode::Bool=false,
|
||||
multithread=false
|
||||
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
|
||||
multithread=false,
|
||||
)::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
|
||||
Tuple{MCTSNode,T,T,Vector{Any}}} where {T<:Any}
|
||||
|
||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||
Dict{Symbol,Any}())
|
||||
|
||||
# [WORKING] storage for holding all high reward terminal nodes
|
||||
highStateValueNode = Channel{Any}(100)
|
||||
# storage for holding all high reward terminal nodes
|
||||
highValueState = Channel{Any}(100)
|
||||
|
||||
for nth in 1:maxiterations
|
||||
node = root
|
||||
@@ -99,7 +99,7 @@ function runMCTS(
|
||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
saveSimulatedNode=saveSimulatedNode,
|
||||
multithread=multithread,
|
||||
highStateValueNode=highStateValueNode,
|
||||
highValueState=highValueState,
|
||||
)
|
||||
end
|
||||
else
|
||||
@@ -109,7 +109,7 @@ function runMCTS(
|
||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||
saveSimulatedNode=saveSimulatedNode,
|
||||
multithread=multithread,
|
||||
highStateValueNode=highStateValueNode)
|
||||
highValueState=highValueState)
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -120,15 +120,23 @@ function runMCTS(
|
||||
end
|
||||
end
|
||||
|
||||
# select the best next state and the best final state
|
||||
# select the best next state and the best terminal state along the best trajectory
|
||||
bestNextState = selectBestNextNode(root)
|
||||
besttrajectory = selectBestTrajectoryNode(root)
|
||||
bestTerminalState = selectBestTrajectoryNode(root)
|
||||
|
||||
#[WORKING] compare all high value answer then select the best one
|
||||
|
||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||
# take all high value state from highValueState channel and put it in a list
|
||||
highValueStateList = Any[]
|
||||
while !isempty(highValueState)
|
||||
push!(highValueStateList, take!(highValueState))
|
||||
end
|
||||
|
||||
return (root=root,
|
||||
bestNextState=bestNextState.state,
|
||||
bestTerminalState=bestTerminalState.state,
|
||||
highValueStateList=highValueStateList)
|
||||
end
|
||||
|
||||
|
||||
""" Search the best action to take for a given state and task
|
||||
|
||||
# Arguments
|
||||
@@ -156,18 +164,18 @@ function simulateThenBackpropagate(node::MCTSNode, transition::Function, transit
|
||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||
saveSimulatedNode::Bool=false,
|
||||
multithread=false,
|
||||
highStateValueNode=Union{Nothing,Any}=nothing)
|
||||
highValueState=Union{Nothing,Any}=nothing)
|
||||
simTrajectoryReward, terminalstate =
|
||||
simulate(node, transition, transitionargs;
|
||||
maxSimulationDepth=maxSimulationDepth,
|
||||
horizontalSample=horizontalSampleSimulationPhase,
|
||||
multithread=multithread)
|
||||
# if a node has state value >= 8, store it in highStateValueNode
|
||||
if highStateValueNode !== nothing &&
|
||||
# if a node has state value >= 8, store it in highValueState
|
||||
if highValueState !== nothing &&
|
||||
terminalstate !== nothing &&
|
||||
terminalstate[:reward] >= 8
|
||||
|
||||
put!(highStateValueNode, deepcopy(terminalstate))
|
||||
put!(highValueState, deepcopy(terminalstate))
|
||||
end
|
||||
|
||||
backpropagate(node, simTrajectoryReward)
|
||||
|
||||
138
src/mcts.jl
138
src/mcts.jl
@@ -1,7 +1,7 @@
|
||||
module mcts
|
||||
|
||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, simulate, makeNewState
|
||||
expand, simulate
|
||||
using Base.Threads
|
||||
using GeneralUtils
|
||||
|
||||
@@ -302,84 +302,84 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
||||
terminalstate=terminalstate)
|
||||
end
|
||||
|
||||
""" Make new state
|
||||
# """ Make new state
|
||||
|
||||
# Arguments
|
||||
- `currentstate::T1`
|
||||
Current state dictionary containing thought history and metadata
|
||||
- `thoughtDict::T4`
|
||||
Dictionary containing new thought and action
|
||||
- `response::T2`
|
||||
Response string from the environment
|
||||
- `select::Union{T3, Nothing}`
|
||||
Selection value or nothing
|
||||
- `reward::T3`
|
||||
Reward value for this state
|
||||
- `isterminal::Bool`
|
||||
Whether this state is terminal
|
||||
# # Arguments
|
||||
# - `currentstate::T1`
|
||||
# Current state dictionary containing thought history and metadata
|
||||
# - `thoughtDict::T4`
|
||||
# Dictionary containing new thought and action
|
||||
# - `response::T2`
|
||||
# Response string from the environment
|
||||
# - `select::Union{T3, Nothing}`
|
||||
# Selection value or nothing
|
||||
# - `reward::T3`
|
||||
# Reward value for this state
|
||||
# - `isterminal::Bool`
|
||||
# Whether this state is terminal
|
||||
|
||||
# Return
|
||||
- `Tuple{String, Dict{Symbol, <:Any}}`
|
||||
A tuple containing:
|
||||
- A unique node key string
|
||||
- A new state dictionary with updated thought history and metadata
|
||||
# # Return
|
||||
# - `Tuple{String, Dict{Symbol, <:Any}}`
|
||||
# A tuple containing:
|
||||
# - A unique node key string
|
||||
# - A new state dictionary with updated thought history and metadata
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
# # Example
|
||||
# ```jldoctest
|
||||
# julia>
|
||||
# ```
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
reward::T3, isterminal::Bool
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
# # Signature
|
||||
# """
|
||||
# function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
# reward::T3, isterminal::Bool
|
||||
# )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
|
||||
# Find the latest thought key and index from current state's thought history
|
||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
# Calculate next index for new thought/action
|
||||
currentstate_nextIndice =
|
||||
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
# Create new keys for thought and action based on next index
|
||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
# # Find the latest thought key and index from current state's thought history
|
||||
# currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
# GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
# # Calculate next index for new thought/action
|
||||
# currentstate_nextIndice =
|
||||
# currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
# # Create new keys for thought and action based on next index
|
||||
# currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
# latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
|
||||
# Find the latest thought index from input thought dictionary
|
||||
_, thoughtDict_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
# # Find the latest thought index from input thought dictionary
|
||||
# _, thoughtDict_latestThoughtIndice =
|
||||
# GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
|
||||
# Determine thought and action keys from thought dictionary
|
||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
if thoughtDict_latestThoughtIndice == -1
|
||||
(:thought, :action)
|
||||
else
|
||||
(
|
||||
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||
)
|
||||
end
|
||||
# # Determine thought and action keys from thought dictionary
|
||||
# thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
# if thoughtDict_latestThoughtIndice == -1
|
||||
# (:thought, :action)
|
||||
# else
|
||||
# (
|
||||
# Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||
# Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||
# )
|
||||
# end
|
||||
|
||||
# Create new state by deep copying current state
|
||||
newstate = deepcopy(currentstate)
|
||||
# Update thought history with new thought
|
||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
thoughtDict[thoughtDict_latestThoughtKey]
|
||||
# Update thought history with new action
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
# Create and add new observation to thought history
|
||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
# Update state metadata
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
# # Create new state by deep copying current state
|
||||
# newstate = deepcopy(currentstate)
|
||||
# # Update thought history with new thought
|
||||
# newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
# thoughtDict[thoughtDict_latestThoughtKey]
|
||||
# # Update thought history with new action
|
||||
# newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
# # Create and add new observation to thought history
|
||||
# newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
# newstate[:thoughtHistory][newObservationKey] = response
|
||||
# # Update state metadata
|
||||
# newstate[:reward] = reward
|
||||
# newstate[:select] = select
|
||||
# newstate[:isterminal] = isterminal
|
||||
|
||||
# Generate unique ID for new node
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
# # Generate unique ID for new node
|
||||
# newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
end
|
||||
# return (newNodeKey, newstate)
|
||||
# end
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user