update
This commit is contained in:
55
src/mcts.jl
55
src/mcts.jl
@@ -211,7 +211,6 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function _expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple)
|
||||
result = transition(node.state, transitionargs)
|
||||
newNodeKey::AbstractString = result[:newNodeKey]
|
||||
@@ -244,17 +243,24 @@ end
|
||||
Arguments for everything the user will use within transition().
|
||||
- `maxSimulationDepth::Integer`
|
||||
maximum depth level MCTS goes vertically during simulation.
|
||||
- horizontalSample::Integer
|
||||
- `horizontalSample::Integer`
|
||||
Total number to sample from the current node (i.e. expand new node horizontally)
|
||||
|
||||
# Keyword Arguments
|
||||
- `multithread::Bool`
|
||||
Whether to run expansion in parallel using multiple threads. Defaults to false.
|
||||
|
||||
# Return
|
||||
- `::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}`
|
||||
- `simTrajectoryReward::Number`
|
||||
Cumulative reward collected along the simulation trajectory
|
||||
- `terminalstate::Union{Dict{Symbol, Any}, Nothing}`
|
||||
Final state if terminal state reached, nothing otherwise
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false)
|
||||
# )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
|
||||
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||
|
||||
simTrajectoryReward = 0.0
|
||||
terminalstate = nothing
|
||||
@@ -275,38 +281,54 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
""" Make new state
|
||||
|
||||
# Arguments
|
||||
|
||||
- `currentstate::T1`
|
||||
Current state dictionary containing thought history and metadata
|
||||
- `thoughtDict::T4`
|
||||
Dictionary containing new thought and action
|
||||
- `response::T2`
|
||||
Response string from the environment
|
||||
- `select::Union{T3, Nothing}`
|
||||
Selection value or nothing
|
||||
- `reward::T3`
|
||||
Reward value for this state
|
||||
- `isterminal::Bool`
|
||||
Whether this state is terminal
|
||||
|
||||
# Return
|
||||
- `Tuple{String, Dict{Symbol, <:Any}}`
|
||||
A tuple containing:
|
||||
- A unique node key string
|
||||
- A new state dictionary with updated thought history and metadata
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||
reward::T3, isterminal::Bool
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||
|
||||
# Find the latest thought key and index from current state's thought history
|
||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||
# Calculate next index for new thought/action
|
||||
currentstate_nextIndice =
|
||||
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||
# Create new keys for thought and action based on next index
|
||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||
|
||||
# Find the latest thought index from input thought dictionary
|
||||
_, thoughtDict_latestThoughtIndice =
|
||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||
|
||||
# Determine thought and action keys from thought dictionary
|
||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||
if thoughtDict_latestThoughtIndice == -1
|
||||
(:thought, :action)
|
||||
@@ -317,17 +339,22 @@ function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::U
|
||||
)
|
||||
end
|
||||
|
||||
# add Thought, action, observation to thoughtHistory
|
||||
# Create new state by deep copying current state
|
||||
newstate = deepcopy(currentstate)
|
||||
# Update thought history with new thought
|
||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||
thoughtDict[thoughtDict_latestThoughtKey]
|
||||
# Update thought history with new action
|
||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||
# Create and add new observation to thought history
|
||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||
newstate[:thoughtHistory][newObservationKey] = response
|
||||
# Update state metadata
|
||||
newstate[:reward] = reward
|
||||
newstate[:select] = select
|
||||
newstate[:isterminal] = isterminal
|
||||
|
||||
# Generate unique ID for new node
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return (newNodeKey, newstate)
|
||||
@@ -394,8 +421,6 @@ end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user