update
This commit is contained in:
28
src/mcts.jl
28
src/mcts.jl
@@ -7,7 +7,7 @@ module mcts
|
||||
|
||||
export MCTSNode, runMCTS, isleaf
|
||||
|
||||
using Dates, UUIDs, DataStructures, JSON3, Random
|
||||
using Dates, UUIDs, DataStructures, JSON3, Random, PrettyPrinting
|
||||
using GeneralUtils
|
||||
using ..type, ..llmfunction
|
||||
|
||||
@@ -211,13 +211,11 @@ end
|
||||
current game state
|
||||
- `thoughtDict::T3`
|
||||
contain Thought, Action, Observation
|
||||
- `isterminal::Function`
|
||||
a function to determine terminal state
|
||||
|
||||
# Return
|
||||
- (newNodeKey, )
|
||||
- `newNodeKey::String`
|
||||
key for newstate
|
||||
- `newstate::Dict{Symbol, Any}`
|
||||
next game state
|
||||
- `(newNodeKey, newstate, isterminalstate, reward)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number}`
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
@@ -238,16 +236,14 @@ julia> thoughtDict = Dict(
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [WORKING] update docstring
|
||||
- [PENDING] add other actions
|
||||
- [] add embedding of newstate and store in newstate[:embedding]
|
||||
- [x] check for terminal state and assign reward
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||
)::Tuple{String, Dict{Symbol, <:Any}, Bool, <:Number} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
|
||||
pprint(thoughtDict)
|
||||
actionname = thoughtDict[:Action][:name]
|
||||
actioninput = thoughtDict[:Action][:input]
|
||||
|
||||
@@ -279,7 +275,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
isterminalstate, reward = isterminal(newstate)
|
||||
|
||||
return newNodeKey, newstate, isterminalstate, reward
|
||||
return (newNodeKey, newstate, isterminalstate, reward)
|
||||
end
|
||||
|
||||
|
||||
@@ -317,21 +313,21 @@ isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||
""" Select child node based on the highest progressValue
|
||||
|
||||
# Arguments
|
||||
|
||||
- `node::MCTSNode`
|
||||
node of a search tree
|
||||
|
||||
# Return
|
||||
- `childNode::MCTSNode`
|
||||
the highest value child node
|
||||
|
||||
# Example
|
||||
```jldoctest
|
||||
julia>
|
||||
```
|
||||
|
||||
# TODO
|
||||
- [WORKING] update docstring
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function selectChildNode(node::MCTSNode)
|
||||
function selectChildNode(node::MCTSNode)::MCTSNode
|
||||
highestProgressValue = 0
|
||||
nodekey = nothing
|
||||
|
||||
|
||||
Reference in New Issue
Block a user