This commit is contained in:
narawat lamaiin
2024-05-05 12:06:32 +07:00
parent 1ac2ad1801
commit 77b590c6ad
4 changed files with 78 additions and 128 deletions

View File

@@ -7,7 +7,7 @@ module mcts
export MCTSNode, runMCTS, isleaf
using Dates, UUIDs, DataStructures, JSON3, Random
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
using GeneralUtils
using ..type, ..llmfunction
@@ -211,13 +211,11 @@ end
current game state
- `thoughtDict::T3`
contain Thought, Action, Observation
- `isterminal::Function`
a function to determine terminal state
# Return
- (newNodeKey, )
- `newNodeKey::String`
key for newstate
- `newstate::Dict{Symbol, Any}`
next game state
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
# Example
```jldoctest
@@ -238,16 +236,14 @@ julia> thoughtDict = Dict(
```
# TODO
- [WORKING] update docstring
- [PENDING] add other actions
- [] add embedding of newstate and store in newstate[:embedding]
- [x] check for terminal state and assign reward
# Signature
"""
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
pprint(thoughtDict)
actionname = thoughtDict[:Action][:name]
actioninput = thoughtDict[:Action][:input]
@@ -279,7 +275,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
newNodeKey = GeneralUtils.uuid4snakecase()
isterminalstate, reward = isterminal(newstate)
return newNodeKey, newstate, isterminalstate, reward
return (newNodeKey, newstate, isterminalstate, reward)
end
@@ -317,21 +313,21 @@ isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Select child node based on the highest progressValue
# Arguments
- `node::MCTSNode`
node of a search tree
# Return
- `childNode::MCTSNode`
the highest value child node
# Example
```jldoctest
julia>
```
# TODO
- [WORKING] update docstring
- [x] implement the function
# Signature
"""
function selectChildNode(node::MCTSNode)
function selectChildNode(node::MCTSNode)::MCTSNode
highestProgressValue = 0
nodekey = nothing