Compare commits
7 Commits
13d0c64183
...
d92333cab4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d92333cab4 | ||
|
|
093290a33b | ||
| c777800948 | |||
| ceced04171 | |||
| ee5f8a8a52 | |||
| 693cbfd82d | |||
| 842626ae35 |
@@ -1,7 +1,7 @@
|
|||||||
name = "LLMMCTS"
|
name = "LLMMCTS"
|
||||||
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
|
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
|
||||||
authors = ["narawat lamaiin <narawat@outlook.com>"]
|
authors = ["narawat lamaiin <narawat@outlook.com>"]
|
||||||
version = "0.1.3"
|
version = "0.1.4"
|
||||||
|
|
||||||
[deps]
|
[deps]
|
||||||
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
|
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
module LLMMCTS
|
module LLMMCTS
|
||||||
|
|
||||||
# export agent
|
export MCTSNode
|
||||||
|
|
||||||
|
|
||||||
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ using ..type, ..mcts, ..util
|
|||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
# ---------------------------------------------- 100 --------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
""" Search the best action to take for a given state and task
|
""" Search the best action to take for a given state and task
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -34,7 +33,7 @@ using ..type, ..mcts, ..util
|
|||||||
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
|
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
|
||||||
aggressively explore new state (default: 1.0)
|
aggressively explore new state (default: 1.0)
|
||||||
- `earlystop::Union{Function,Nothing}`
|
- `earlystop::Union{Function,Nothing}`
|
||||||
optional function to check early stopping condition (default: nothing)
|
optional function to check the early stopping condition; if it is satisfied, MCTS stops iterating (default: nothing)
|
||||||
- `saveSimulatedNode::Bool`
|
- `saveSimulatedNode::Bool`
|
||||||
whether to save nodes created during simulation phase (default: false)
|
whether to save nodes created during simulation phase (default: false)
|
||||||
- `multithread::Bool`
|
- `multithread::Bool`
|
||||||
@@ -63,12 +62,16 @@ function runMCTS(
|
|||||||
explorationweight::Number=1.0,
|
explorationweight::Number=1.0,
|
||||||
earlystop::Union{Function,Nothing}=nothing,
|
earlystop::Union{Function,Nothing}=nothing,
|
||||||
saveSimulatedNode::Bool=false,
|
saveSimulatedNode::Bool=false,
|
||||||
multithread=false
|
multithread=false,
|
||||||
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any}
|
)::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
|
||||||
|
Tuple{MCTSNode,T,T,Vector{Dict{Symbol,Any}}}} where {T<:Any}
|
||||||
|
|
||||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||||
Dict{Symbol,Any}())
|
Dict{Symbol,Any}())
|
||||||
|
|
||||||
|
# buffered channel (capacity 100) for collecting high-reward terminal states
|
||||||
|
highValueState = Channel{Any}(100)
|
||||||
|
|
||||||
for nth in 1:maxiterations
|
for nth in 1:maxiterations
|
||||||
node = root
|
node = root
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
@@ -78,6 +81,10 @@ function runMCTS(
|
|||||||
end
|
end
|
||||||
|
|
||||||
if node.isterminal
|
if node.isterminal
|
||||||
|
if node.state[:reward] >= 8
|
||||||
|
put!(highValueState, deepcopy(node.state))
|
||||||
|
end
|
||||||
|
|
||||||
# MCTS arrive at the leaf node that is also a terminal state,
|
# MCTS arrive at the leaf node that is also a terminal state,
|
||||||
# do nothing then go directly to backpropagation. It means the end of this iteration
|
# do nothing then go directly to backpropagation. It means the end of this iteration
|
||||||
backpropagate(node, node.reward)
|
backpropagate(node, node.reward)
|
||||||
@@ -91,15 +98,18 @@ function runMCTS(
|
|||||||
maxSimulationDepth=maxSimulationDepth,
|
maxSimulationDepth=maxSimulationDepth,
|
||||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||||
saveSimulatedNode=saveSimulatedNode,
|
saveSimulatedNode=saveSimulatedNode,
|
||||||
multithread=multithread)
|
multithread=multithread,
|
||||||
|
highValueState=highValueState,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
for (leafNodeKey, leafNode) in node.children
|
for (leafNodeKey, leafNode) in node.children
|
||||||
simulateThenBackpropagate(leafNode, transition, transitionargs;
|
simulateThenBackpropagate(leafNode, transition, transitionargs;
|
||||||
maxSimulationDepth=maxSimulationDepth,
|
maxSimulationDepth=maxSimulationDepth,
|
||||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||||
saveSimulatedNode=saveSimulatedNode,
|
saveSimulatedNode=saveSimulatedNode,
|
||||||
multithread=multithread)
|
multithread=multithread,
|
||||||
|
highValueState=highValueState)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -110,11 +120,24 @@ function runMCTS(
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# select the best next state and the best final state
|
# select the best next state and the best terminal state along the best trajectory
|
||||||
bestNextState = selectBestNextNode(root)
|
bestNextState = selectBestNextNode(root)
|
||||||
besttrajectory = selectBestTrajectoryNode(root)
|
bestTerminalState = selectBestTrajectoryNode(root)
|
||||||
|
|
||||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
# drain every high-value state from the highValueState channel into a list
|
||||||
|
highValueStateList = Vector{Dict{Symbol, Any}}()
|
||||||
|
while !isempty(highValueState)
|
||||||
|
push!(highValueStateList, take!(highValueState))
|
||||||
|
end
|
||||||
|
|
||||||
|
result = (
|
||||||
|
root=root,
|
||||||
|
bestNextState=bestNextState.state,
|
||||||
|
bestTerminalState=bestTerminalState.state,
|
||||||
|
highValueStateList=highValueStateList
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Search the best action to take for a given state and task
|
""" Search the best action to take for a given state and task
|
||||||
@@ -143,11 +166,21 @@ end
|
|||||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||||
saveSimulatedNode::Bool=false,
|
saveSimulatedNode::Bool=false,
|
||||||
multithread=false)
|
multithread=false,
|
||||||
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
|
highValueState::Union{Nothing,Any}=nothing)
|
||||||
maxSimulationDepth=maxSimulationDepth,
|
simTrajectoryReward, terminalstate =
|
||||||
horizontalSample=horizontalSampleSimulationPhase,
|
simulate(node, transition, transitionargs;
|
||||||
multithread=multithread)
|
maxSimulationDepth=maxSimulationDepth,
|
||||||
|
horizontalSample=horizontalSampleSimulationPhase,
|
||||||
|
multithread=multithread)
|
||||||
|
# if the simulated terminal state has a reward >= 8, store a copy of it in highValueState
|
||||||
|
if highValueState !== nothing &&
|
||||||
|
terminalstate !== nothing &&
|
||||||
|
terminalstate[:reward] >= 8
|
||||||
|
|
||||||
|
put!(highValueState, deepcopy(terminalstate))
|
||||||
|
end
|
||||||
|
|
||||||
backpropagate(node, simTrajectoryReward)
|
backpropagate(node, simTrajectoryReward)
|
||||||
|
|
||||||
# check if the user wants to keep the simulated node
|
# check if the user wants to keep the simulated node
|
||||||
|
|||||||
143
src/mcts.jl
143
src/mcts.jl
@@ -1,7 +1,7 @@
|
|||||||
module mcts
|
module mcts
|
||||||
|
|
||||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||||
expand, simulate, makeNewState
|
expand, simulate
|
||||||
using Base.Threads
|
using Base.Threads
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ end
|
|||||||
"""
|
"""
|
||||||
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||||
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
|
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
|
||||||
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
|
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
|
||||||
|
|
||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
terminalstate = nothing
|
terminalstate = nothing
|
||||||
@@ -298,87 +298,88 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
return (simTrajectoryReward=simTrajectoryReward,
|
||||||
|
terminalstate=terminalstate)
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Make new state
|
# """ Make new state
|
||||||
|
|
||||||
# Arguments
|
# # Arguments
|
||||||
- `currentstate::T1`
|
# - `currentstate::T1`
|
||||||
Current state dictionary containing thought history and metadata
|
# Current state dictionary containing thought history and metadata
|
||||||
- `thoughtDict::T4`
|
# - `thoughtDict::T4`
|
||||||
Dictionary containing new thought and action
|
# Dictionary containing new thought and action
|
||||||
- `response::T2`
|
# - `response::T2`
|
||||||
Response string from the environment
|
# Response string from the environment
|
||||||
- `select::Union{T3, Nothing}`
|
# - `select::Union{T3, Nothing}`
|
||||||
Selection value or nothing
|
# Selection value or nothing
|
||||||
- `reward::T3`
|
# - `reward::T3`
|
||||||
Reward value for this state
|
# Reward value for this state
|
||||||
- `isterminal::Bool`
|
# - `isterminal::Bool`
|
||||||
Whether this state is terminal
|
# Whether this state is terminal
|
||||||
|
|
||||||
# Return
|
# # Return
|
||||||
- `Tuple{String, Dict{Symbol, <:Any}}`
|
# - `Tuple{String, Dict{Symbol, <:Any}}`
|
||||||
A tuple containing:
|
# A tuple containing:
|
||||||
- A unique node key string
|
# - A unique node key string
|
||||||
- A new state dictionary with updated thought history and metadata
|
# - A new state dictionary with updated thought history and metadata
|
||||||
|
|
||||||
# Example
|
# # Example
|
||||||
```jldoctest
|
# ```jldoctest
|
||||||
julia>
|
# julia>
|
||||||
```
|
# ```
|
||||||
|
|
||||||
# Signature
|
# # Signature
|
||||||
"""
|
# """
|
||||||
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
# function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
|
||||||
reward::T3, isterminal::Bool
|
# reward::T3, isterminal::Bool
|
||||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
# )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
|
||||||
|
|
||||||
# Find the latest thought key and index from current state's thought history
|
# # Find the latest thought key and index from current state's thought history
|
||||||
currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
# currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
|
||||||
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
# GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
|
||||||
# Calculate next index for new thought/action
|
# # Calculate next index for new thought/action
|
||||||
currentstate_nextIndice =
|
# currentstate_nextIndice =
|
||||||
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
# currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
|
||||||
# Create new keys for thought and action based on next index
|
# # Create new keys for thought and action based on next index
|
||||||
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
# currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
|
||||||
latestActionKey = Symbol("action_$currentstate_nextIndice")
|
# latestActionKey = Symbol("action_$currentstate_nextIndice")
|
||||||
|
|
||||||
# Find the latest thought index from input thought dictionary
|
# # Find the latest thought index from input thought dictionary
|
||||||
_, thoughtDict_latestThoughtIndice =
|
# _, thoughtDict_latestThoughtIndice =
|
||||||
GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
# GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
|
||||||
|
|
||||||
# Determine thought and action keys from thought dictionary
|
# # Determine thought and action keys from thought dictionary
|
||||||
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
# thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
|
||||||
if thoughtDict_latestThoughtIndice == -1
|
# if thoughtDict_latestThoughtIndice == -1
|
||||||
(:thought, :action)
|
# (:thought, :action)
|
||||||
else
|
# else
|
||||||
(
|
# (
|
||||||
Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
# Symbol("thought_$thoughtDict_latestThoughtIndice"),
|
||||||
Symbol("action_$thoughtDict_latestThoughtIndice"),
|
# Symbol("action_$thoughtDict_latestThoughtIndice"),
|
||||||
)
|
# )
|
||||||
end
|
# end
|
||||||
|
|
||||||
# Create new state by deep copying current state
|
# # Create new state by deep copying current state
|
||||||
newstate = deepcopy(currentstate)
|
# newstate = deepcopy(currentstate)
|
||||||
# Update thought history with new thought
|
# # Update thought history with new thought
|
||||||
newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
# newstate[:thoughtHistory][currentstate_latestThoughtKey] =
|
||||||
thoughtDict[thoughtDict_latestThoughtKey]
|
# thoughtDict[thoughtDict_latestThoughtKey]
|
||||||
# Update thought history with new action
|
# # Update thought history with new action
|
||||||
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
# newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
|
||||||
# Create and add new observation to thought history
|
# # Create and add new observation to thought history
|
||||||
newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
# newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
|
||||||
newstate[:thoughtHistory][newObservationKey] = response
|
# newstate[:thoughtHistory][newObservationKey] = response
|
||||||
# Update state metadata
|
# # Update state metadata
|
||||||
newstate[:reward] = reward
|
# newstate[:reward] = reward
|
||||||
newstate[:select] = select
|
# newstate[:select] = select
|
||||||
newstate[:isterminal] = isterminal
|
# newstate[:isterminal] = isterminal
|
||||||
|
|
||||||
# Generate unique ID for new node
|
# # Generate unique ID for new node
|
||||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
# newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
|
|
||||||
return (newNodeKey, newstate)
|
# return (newNodeKey, newstate)
|
||||||
end
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
41
test/Manifest.toml
Normal file
41
test/Manifest.toml
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# This file is machine-generated - editing it directly is not advised
|
||||||
|
|
||||||
|
julia_version = "1.11.4"
|
||||||
|
manifest_format = "2.0"
|
||||||
|
project_hash = "71d91126b5a1fb1020e1098d9d492de2a4438fd2"
|
||||||
|
|
||||||
|
[[deps.Base64]]
|
||||||
|
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||||
|
version = "1.11.0"
|
||||||
|
|
||||||
|
[[deps.InteractiveUtils]]
|
||||||
|
deps = ["Markdown"]
|
||||||
|
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||||
|
version = "1.11.0"
|
||||||
|
|
||||||
|
[[deps.Logging]]
|
||||||
|
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||||
|
version = "1.11.0"
|
||||||
|
|
||||||
|
[[deps.Markdown]]
|
||||||
|
deps = ["Base64"]
|
||||||
|
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
|
||||||
|
version = "1.11.0"
|
||||||
|
|
||||||
|
[[deps.Random]]
|
||||||
|
deps = ["SHA"]
|
||||||
|
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||||
|
version = "1.11.0"
|
||||||
|
|
||||||
|
[[deps.SHA]]
|
||||||
|
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
||||||
|
version = "0.7.0"
|
||||||
|
|
||||||
|
[[deps.Serialization]]
|
||||||
|
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
|
||||||
|
version = "1.11.0"
|
||||||
|
|
||||||
|
[[deps.Test]]
|
||||||
|
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
|
||||||
|
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||||
|
version = "1.11.0"
|
||||||
2
test/Project.toml
Normal file
2
test/Project.toml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
[deps]
|
||||||
|
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||||
Reference in New Issue
Block a user