update
This commit is contained in:
55
src/mcts.jl
55
src/mcts.jl
@@ -211,7 +211,6 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
||||||
result = transition(node.state, transitionargs)
|
result = transition(node.state, transitionargs)
|
||||||
newNodeKey::AbstractString = result[:newNodeKey]
|
newNodeKey::AbstractString = result[:newNodeKey]
|
||||||
@@ -244,17 +243,24 @@ end
|
|||||||
Arguments for everything the user will use within transition().
|
Arguments for everything the user will use within transition().
|
||||||
- `maxSimulationDepth::Integer`
|
- `maxSimulationDepth::Integer`
|
||||||
maximum depth level MCTS goes vertically during simulation.
|
maximum depth level MCTS goes vertically during simulation.
|
||||||
- horizontalSample::Integer
|
- `horizontalSample::Integer`
|
||||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||||
|
|
||||||
|
# Keyword Arguments
|
||||||
|
- `multithread::Bool`
|
||||||
|
Whether to run expansion in parallel using multiple threads. Defaults to false.
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
|
- `simTrajectoryReward::Number`
|
||||||
|
Cumulative reward collected along the simulation trajectory
|
||||||
|
- `terminalstate::Union{Dict{Symbol, Any}, Nothing}`
|
||||||
|
Final state if terminal state reached, nothing otherwise
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||||
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false)
|
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
|
||||||
# )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
|
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||||
|
|
||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
terminalstate = nothing
|
terminalstate = nothing
|
||||||
@@ -275,38 +281,54 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
|||||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
""" Make new state
|
||||||
"""
|
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
- `currentstate::T1`
|
||||||
|
Current state dictionary containing thought history and metadata
|
||||||
|
- `thoughtDict::T4`
|
||||||
|
Dictionary containing new thought and action
|
||||||
|
- `response::T2`
|
||||||
|
Response string from the environment
|
||||||
|
- `select::Union{T3, Nothing}`
|
||||||
|
Selection value or nothing
|
||||||
|
- `reward::T3`
|
||||||
|
Reward value for this state
|
||||||
|
- `isterminal::Bool`
|
||||||
|
Whether this state is terminal
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
|
- `Tuple{String, Dict{Symbol, <:Any}}`
|
||||||
|
A tuple containing:
|
||||||
|
- A unique node key string
|
||||||
|
- A new state dictionary with updated thought history and metadata
|
||||||
|
|
||||||
# Example
|
# Example
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia>
|
julia>
|
||||||
```
|
```
|
||||||
|
|
||||||
# TODO
|
|
||||||
- [] update docstring
|
|
||||||
- [x] implement the function
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||||
reward::T3, isterminal::Bool
|
reward::T3, isterminal::Bool
|
||||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||||
|
|
||||||
|
# Find the latest thought key and index from current state's thought history
|
||||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||||
|
# Calculate next index for new thought/action
|
||||||
currentstate_nextIndice =
|
currentstate_nextIndice =
|
||||||
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||||
|
# Create new keys for thought and action based on next index
|
||||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||||
|
|
||||||
|
# Find the latest thought index from input thought dictionary
|
||||||
_, thoughtDict_latestThoughtIndice =
|
_, thoughtDict_latestThoughtIndice =
|
||||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||||
|
|
||||||
|
# Determine thought and action keys from thought dictionary
|
||||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||||
if thoughtDict_latestThoughtIndice == -1
|
if thoughtDict_latestThoughtIndice == -1
|
||||||
(:thought, :action)
|
(:thought, :action)
|
||||||
@@ -317,17 +339,22 @@ function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::U
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
# add Thought, action, observation to thoughtHistory
|
# Create new state by deep copying current state
|
||||||
newstate = deepcopy(currentstate)
|
newstate = deepcopy(currentstate)
|
||||||
|
# Update thought history with new thought
|
||||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||||
thoughtDict[thoughtDict_latestThoughtKey]
|
thoughtDict[thoughtDict_latestThoughtKey]
|
||||||
|
# Update thought history with new action
|
||||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||||
|
# Create and add new observation to thought history
|
||||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||||
newstate[:thoughtHistory][newObservationKey] = response
|
newstate[:thoughtHistory][newObservationKey] = response
|
||||||
|
# Update state metadata
|
||||||
newstate[:reward] = reward
|
newstate[:reward] = reward
|
||||||
newstate[:select] = select
|
newstate[:select] = select
|
||||||
newstate[:isterminal] = isterminal
|
newstate[:isterminal] = isterminal
|
||||||
|
|
||||||
|
# Generate unique ID for new node
|
||||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
|
|
||||||
return (newNodeKey, newstate)
|
return (newNodeKey, newstate)
|
||||||
@@ -394,8 +421,6 @@ end
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user