Compare commits

...

7 Commits

Author SHA1 Message Date
narawat lamaiin
d92333cab4 update 2025-05-18 17:22:21 +07:00
narawat lamaiin
093290a33b update 2025-03-22 21:33:20 +07:00
c777800948 update 2025-03-20 16:15:08 +07:00
ceced04171 update 2025-03-20 05:45:58 +07:00
ee5f8a8a52 update 2025-03-18 21:23:09 +07:00
693cbfd82d update 2025-03-16 22:11:38 +07:00
842626ae35 mark new version 2025-03-16 18:17:23 +07:00
7 changed files with 167 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
name = "LLMMCTS" name = "LLMMCTS"
uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241" uuid = "d76c5a4d-449e-4835-8cc4-dd86ec44f241"
authors = ["narawat lamaiin <narawat@outlook.com>"] authors = ["narawat lamaiin <narawat@outlook.com>"]
version = "0.1.3" version = "0.1.4"
[deps] [deps]
GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" GeneralUtils = "c6c72f09-b708-4ac8-ac7c-2084d70108fe"

View File

@@ -1,6 +1,6 @@
module LLMMCTS module LLMMCTS
# export agent export MCTSNode
""" Order by dependencies of each file. The 1st included file must not depend on any other """ Order by dependencies of each file. The 1st included file must not depend on any other

View File

@@ -9,7 +9,6 @@ using ..type, ..mcts, ..util
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
# Arguments # Arguments
@@ -34,7 +33,7 @@ using ..type, ..mcts, ..util
a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS a known state. 1.0 balance between exploration and exploitation like 50%-50%. 2.0 makes MCTS
aggressively explore new state (default: 1.0) aggressively explore new state (default: 1.0)
- `earlystop::Union{Function,Nothing}` - `earlystop::Union{Function,Nothing}`
optional function to check early stopping condition (default: nothing) optional function to check early stopping condition if it is satisfied, MCTS will break iterations (default: nothing)
- `saveSimulatedNode::Bool` - `saveSimulatedNode::Bool`
whether to save nodes created during simulation phase (default: false) whether to save nodes created during simulation phase (default: false)
- `multithread::Bool` - `multithread::Bool`
@@ -63,12 +62,16 @@ function runMCTS(
explorationweight::Number=1.0, explorationweight::Number=1.0,
earlystop::Union{Function,Nothing}=nothing, earlystop::Union{Function,Nothing}=nothing,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
multithread=false multithread=false,
)::NamedTuple{(:root, :bestNextState, :bestFinalState),Tuple{MCTSNode,T,T}} where {T<:Any} )::NamedTuple{(:root, :bestNextState, :bestTerminalState, :highValueStateList),
Tuple{MCTSNode,T,T,Vector{Dict{Symbol,Any}}}} where {T<:Any}
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}()) Dict{Symbol,Any}())
# storage for holding all high reward terminal nodes
highValueState = Channel{Any}(100)
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
node.visits += 1 node.visits += 1
@@ -78,6 +81,10 @@ function runMCTS(
end end
if node.isterminal if node.isterminal
if node.state[:reward] >= 8
put!(highrewardNode, deepcopy(node.state))
end
# MCTS arrive at the leaf node that is also a terminal state, # MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration # do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
@@ -91,15 +98,18 @@ function runMCTS(
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode, saveSimulatedNode=saveSimulatedNode,
multithread=multithread) multithread=multithread,
highValueState=highValueState,
)
end end
else else
for (leafNodeKey, leafNode) in node.children for (leafNodeKey, leafNode) in node.children
simulateThenBackpropagate(leafNode, transition, transitionargs; simulateThenBackpropagate(leafNode, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode, saveSimulatedNode=saveSimulatedNode,
multithread=multithread) multithread=multithread,
highValueState=highValueState)
end end
end end
end end
@@ -110,11 +120,24 @@ function runMCTS(
end end
end end
# select the best next state and the best final state # select the best next state and the best terminal state along the best trajectory
bestNextState = selectBestNextNode(root) bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root) bestTerminalState = selectBestTrajectoryNode(root)
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) # take all high value state from highValueState channel and put it in a list
highValueStateList = Vector{Dict{Symbol, Any}}()
while !isempty(highValueState)
push!(highValueStateList, take!(highValueState))
end
result = (
root=root,
bestNextState=bestNextState.state,
bestTerminalState=bestTerminalState.state,
highValueStateList=highValueStateList
)
return result
end end
""" Search the best action to take for a given state and task """ Search the best action to take for a given state and task
@@ -143,11 +166,21 @@ end
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
multithread=false) multithread=false,
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs; highValueState=Union{Nothing,Any}=nothing)
maxSimulationDepth=maxSimulationDepth, simTrajectoryReward, terminalstate =
horizontalSample=horizontalSampleSimulationPhase, simulate(node, transition, transitionargs;
multithread=multithread) maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread)
# if a node has state value >= 8, store it in highValueState
if highValueState !== nothing &&
terminalstate !== nothing &&
terminalstate[:reward] >= 8
put!(highValueState, deepcopy(terminalstate))
end
backpropagate(node, simTrajectoryReward) backpropagate(node, simTrajectoryReward)
# check if the user wants to keep the simulated node # check if the user wants to keep the simulated node

View File

@@ -1,7 +1,7 @@
module mcts module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode, export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState expand, simulate
using Base.Threads using Base.Threads
using GeneralUtils using GeneralUtils
@@ -280,7 +280,7 @@ end
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}} )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
@@ -298,87 +298,88 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
end end
end end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) return (simTrajectoryReward=simTrajectoryReward,
terminalstate=terminalstate)
end end
""" Make new state # """ Make new state
# Arguments # # Arguments
- `currentstate::T1` # - `currentstate::T1`
Current state dictionary containing thought history and metadata # Current state dictionary containing thought history and metadata
- `thoughtDict::T4` # - `thoughtDict::T4`
Dictionary containing new thought and action # Dictionary containing new thought and action
- `response::T2` # - `response::T2`
Response string from the environment # Response string from the environment
- `select::Union{T3, Nothing}` # - `select::Union{T3, Nothing}`
Selection value or nothing # Selection value or nothing
- `reward::T3` # - `reward::T3`
Reward value for this state # Reward value for this state
- `isterminal::Bool` # - `isterminal::Bool`
Whether this state is terminal # Whether this state is terminal
# Return # # Return
- `Tuple{String, Dict{Symbol, <:Any}}` # - `Tuple{String, Dict{Symbol, <:Any}}`
A tuple containing: # A tuple containing:
- A unique node key string # - A unique node key string
- A new state dictionary with updated thought history and metadata # - A new state dictionary with updated thought history and metadata
# Example # # Example
```jldoctest # ```jldoctest
julia> # julia>
``` # ```
# Signature # # Signature
""" # """
function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing}, # function makeNewState(currentstate::T1, thoughtDict::T4, response::T2, select::Union{T3, Nothing},
reward::T3, isterminal::Bool # reward::T3, isterminal::Bool
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict} # )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractString, T3<:Number, T4<:AbstractDict}
# Find the latest thought key and index from current state's thought history # # Find the latest thought key and index from current state's thought history
currentstate_latestThoughtKey, currentstate_latestThoughtIndice = # currentstate_latestThoughtKey, currentstate_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought") # GeneralUtils.findHighestIndexKey(currentstate[:thoughtHistory], "thought")
# Calculate next index for new thought/action # # Calculate next index for new thought/action
currentstate_nextIndice = # currentstate_nextIndice =
currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1 # currentstate_latestThoughtKey == :NA ? 1 : currentstate_latestThoughtIndice + 1
# Create new keys for thought and action based on next index # # Create new keys for thought and action based on next index
currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice") # currentstate_latestThoughtKey = Symbol("thought_$currentstate_nextIndice")
latestActionKey = Symbol("action_$currentstate_nextIndice") # latestActionKey = Symbol("action_$currentstate_nextIndice")
# Find the latest thought index from input thought dictionary # # Find the latest thought index from input thought dictionary
_, thoughtDict_latestThoughtIndice = # _, thoughtDict_latestThoughtIndice =
GeneralUtils.findHighestIndexKey(thoughtDict, "thought") # GeneralUtils.findHighestIndexKey(thoughtDict, "thought")
# Determine thought and action keys from thought dictionary # # Determine thought and action keys from thought dictionary
thoughtDict_latestThoughtKey, thoughtDict_latestActionKey = # thoughtDict_latestThoughtKey, thoughtDict_latestActionKey =
if thoughtDict_latestThoughtIndice == -1 # if thoughtDict_latestThoughtIndice == -1
(:thought, :action) # (:thought, :action)
else # else
( # (
Symbol("thought_$thoughtDict_latestThoughtIndice"), # Symbol("thought_$thoughtDict_latestThoughtIndice"),
Symbol("action_$thoughtDict_latestThoughtIndice"), # Symbol("action_$thoughtDict_latestThoughtIndice"),
) # )
end # end
# Create new state by deep copying current state # # Create new state by deep copying current state
newstate = deepcopy(currentstate) # newstate = deepcopy(currentstate)
# Update thought history with new thought # # Update thought history with new thought
newstate[:thoughtHistory][currentstate_latestThoughtKey] = # newstate[:thoughtHistory][currentstate_latestThoughtKey] =
thoughtDict[thoughtDict_latestThoughtKey] # thoughtDict[thoughtDict_latestThoughtKey]
# Update thought history with new action # # Update thought history with new action
newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey] # newstate[:thoughtHistory][latestActionKey] = thoughtDict[thoughtDict_latestActionKey]
# Create and add new observation to thought history # # Create and add new observation to thought history
newObservationKey = Symbol("observation_$(currentstate_nextIndice)") # newObservationKey = Symbol("observation_$(currentstate_nextIndice)")
newstate[:thoughtHistory][newObservationKey] = response # newstate[:thoughtHistory][newObservationKey] = response
# Update state metadata # # Update state metadata
newstate[:reward] = reward # newstate[:reward] = reward
newstate[:select] = select # newstate[:select] = select
newstate[:isterminal] = isterminal # newstate[:isterminal] = isterminal
# Generate unique ID for new node # # Generate unique ID for new node
newNodeKey = GeneralUtils.uuid4snakecase() # newNodeKey = GeneralUtils.uuid4snakecase()
return (newNodeKey, newstate) # return (newNodeKey, newstate)
end # end

41
test/Manifest.toml Normal file
View File

@@ -0,0 +1,41 @@
# This file is machine-generated - editing it directly is not advised
julia_version = "1.11.4"
manifest_format = "2.0"
project_hash = "71d91126b5a1fb1020e1098d9d492de2a4438fd2"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
version = "1.11.0"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
version = "1.11.0"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
version = "1.11.0"
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
version = "1.11.0"
[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
version = "1.11.0"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
version = "1.11.0"
[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
version = "1.11.0"

2
test/Project.toml Normal file
View File

@@ -0,0 +1,2 @@
[deps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"