This commit is contained in:
2025-03-16 22:11:38 +07:00
parent 842626ae35
commit 693cbfd82d
3 changed files with 36 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
module LLMMCTS module LLMMCTS
# export agent export MCTSNode
""" Order by dependencies of each file. The 1st included file must not depend on any other """ Order by dependencies of each file. The 1st included file must not depend on any other

View File

@@ -69,6 +69,9 @@ function runMCTS(
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(), root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}()) Dict{Symbol,Any}())
# [WORKING] storage for the states of all high-reward terminal nodes found during the search
highStateValueNode = Channel{Any}(100)
for nth in 1:maxiterations for nth in 1:maxiterations
node = root node = root
node.visits += 1 node.visits += 1
@@ -78,6 +81,10 @@ function runMCTS(
end end
if node.isterminal if node.isterminal
if node.state[:reward] >= 8
put!(highrewardNode, deepcopy(node.state))
end
# MCTS arrive at the leaf node that is also a terminal state, # MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration # do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward) backpropagate(node, node.reward)
@@ -91,7 +98,9 @@ function runMCTS(
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode, saveSimulatedNode=saveSimulatedNode,
multithread=multithread) multithread=multithread,
highStateValueNode=highStateValueNode,
)
end end
else else
for (leafNodeKey, leafNode) in node.children for (leafNodeKey, leafNode) in node.children
@@ -99,7 +108,8 @@ function runMCTS(
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase, horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode, saveSimulatedNode=saveSimulatedNode,
multithread=multithread) multithread=multithread,
highStateValueNode=highStateValueNode)
end end
end end
end end
@@ -114,6 +124,8 @@ function runMCTS(
bestNextState = selectBestNextNode(root) bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root) besttrajectory = selectBestTrajectoryNode(root)
# [WORKING] compare all high-value answers, then select the best one
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end end
@@ -143,11 +155,21 @@ end
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3, maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false, saveSimulatedNode::Bool=false,
multithread=false) multithread=false,
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs; highStateValueNode=Union{Nothing,Any}=nothing)
simTrajectoryReward, terminalstate =
simulate(node, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth, maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase, horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread) multithread=multithread)
# if the simulated terminal state has a reward >= 8, store a copy of it in highStateValueNode
if highStateValueNode !== nothing &&
terminalstate !== nothing &&
terminalstate[:reward] >= 8
put!(highStateValueNode, deepcopy(terminalstate))
end
backpropagate(node, simTrajectoryReward) backpropagate(node, simTrajectoryReward)
# check if the user wants to keep the simulated node # check if the user wants to keep the simulated node

View File

@@ -280,7 +280,7 @@ end
""" """
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}} )::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0 simTrajectoryReward = 0.0
terminalstate = nothing terminalstate = nothing
@@ -298,7 +298,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
end end
end end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate) return (simTrajectoryReward=simTrajectoryReward,
terminalstate=terminalstate)
end end
""" Make new state """ Make new state