diff --git a/src/LLMMCTS.jl b/src/LLMMCTS.jl index c96bbd4..c3b98de 100644 --- a/src/LLMMCTS.jl +++ b/src/LLMMCTS.jl @@ -1,6 +1,6 @@ module LLMMCTS - # export agent + export MCTSNode """ Order by dependencies of each file. The 1st included file must not depend on any other diff --git a/src/interface.jl b/src/interface.jl index 32a8b3f..8251df4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -69,6 +69,9 @@ function runMCTS( root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), Dict{Symbol,Any}()) + # [WORKING] unbounded buffer of high-reward terminal states (nothing drains it yet, so put! must never block) + highStateValueNode = Channel{Any}(Inf) + for nth in 1:maxiterations node = root node.visits += 1 @@ -78,6 +81,10 @@ function runMCTS( end if node.isterminal + if node.state[:reward] >= 8 + put!(highStateValueNode, deepcopy(node.state)) + end + # MCTS arrive at the leaf node that is also a terminal state, # do nothing then go directly to backpropagation. It means the end of this iteration backpropagate(node, node.reward) @@ -91,15 +98,18 @@ function runMCTS( maxSimulationDepth=maxSimulationDepth, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, saveSimulatedNode=saveSimulatedNode, - multithread=multithread) + multithread=multithread, + highStateValueNode=highStateValueNode, + ) end else for (leafNodeKey, leafNode) in node.children simulateThenBackpropagate(leafNode, transition, transitionargs; - maxSimulationDepth=maxSimulationDepth, - horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, - saveSimulatedNode=saveSimulatedNode, - multithread=multithread) + maxSimulationDepth=maxSimulationDepth, + horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, + saveSimulatedNode=saveSimulatedNode, + multithread=multithread, + highStateValueNode=highStateValueNode) end end end @@ -114,6 +124,8 @@ function runMCTS( bestNextState = selectBestNextNode(root) besttrajectory = selectBestTrajectoryNode(root) + #[WORKING] compare all high value answers, 
then select the best one + return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) end @@ -143,11 +155,21 @@ end function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, saveSimulatedNode::Bool=false, - multithread=false) - simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs; - maxSimulationDepth=maxSimulationDepth, - horizontalSample=horizontalSampleSimulationPhase, - multithread=multithread) + multithread=false, + highStateValueNode::Union{Nothing,Channel}=nothing) + simTrajectoryReward, terminalstate = + simulate(node, transition, transitionargs; + maxSimulationDepth=maxSimulationDepth, + horizontalSample=horizontalSampleSimulationPhase, + multithread=multithread) + # if a node has state value >= 8, store it in highStateValueNode + if highStateValueNode !== nothing && + terminalstate !== nothing && + terminalstate[:reward] >= 8 + + put!(highStateValueNode, deepcopy(terminalstate)) + end + backpropagate(node, simTrajectoryReward) # check if the user wants to keep the simulated node diff --git a/src/mcts.jl b/src/mcts.jl index 2715baa..59b50e5 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -280,7 +280,7 @@ end """ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false -)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}} + )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}} simTrajectoryReward = 0.0 terminalstate = nothing @@ -298,7 +298,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup end end - return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) + return (simTrajectoryReward=simTrajectoryReward, 
terminalstate=terminalstate) end """ Make new state