This commit is contained in:
2025-03-16 22:11:38 +07:00
parent 842626ae35
commit 693cbfd82d
3 changed files with 36 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
module LLMMCTS
# export agent
export MCTSNode
""" Order by dependencies of each file. The 1st included file must not depend on any other

View File

@@ -69,6 +69,9 @@ function runMCTS(
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
Dict{Symbol,Any}())
# [WORKING] storage for holding all high reward terminal nodes
highStateValueNode = Channel{Any}(100)
for nth in 1:maxiterations
node = root
node.visits += 1
@@ -78,6 +81,10 @@ function runMCTS(
end
if node.isterminal
if node.state[:reward] >= 8
put!(highrewardNode, deepcopy(node.state))
end
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation. It means the end of this iteration
backpropagate(node, node.reward)
@@ -91,15 +98,18 @@ function runMCTS(
maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode,
multithread=multithread)
multithread=multithread,
highStateValueNode=highStateValueNode,
)
end
else
for (leafNodeKey, leafNode) in node.children
simulateThenBackpropagate(leafNode, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode,
multithread=multithread)
maxSimulationDepth=maxSimulationDepth,
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
saveSimulatedNode=saveSimulatedNode,
multithread=multithread,
highStateValueNode=highStateValueNode)
end
end
end
@@ -114,6 +124,8 @@ function runMCTS(
bestNextState = selectBestNextNode(root)
besttrajectory = selectBestTrajectoryNode(root)
#[WORKING] compare all high value answer then select the best one
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
end
@@ -143,11 +155,21 @@ end
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
saveSimulatedNode::Bool=false,
multithread=false)
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread)
multithread=false,
highStateValueNode=Union{Nothing,Any}=nothing)
simTrajectoryReward, terminalstate =
simulate(node, transition, transitionargs;
maxSimulationDepth=maxSimulationDepth,
horizontalSample=horizontalSampleSimulationPhase,
multithread=multithread)
# if a node has state value >= 8, store it in highStateValueNode
if highStateValueNode !== nothing &&
terminalstate !== nothing &&
terminalstate[:reward] >= 8
put!(highStateValueNode, deepcopy(terminalstate))
end
backpropagate(node, simTrajectoryReward)
# check if the user wants to keep the simulated node

View File

@@ -280,7 +280,7 @@ end
"""
function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
maxSimulationDepth::Integer=3, horizontalSample::Integer=3, multithread=false
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
)::NamedTuple{(:simTrajectoryReward, :terminalstate), Tuple{<:Number, Union{Dict{Symbol, Any}, Nothing}}}
simTrajectoryReward = 0.0
terminalstate = nothing
@@ -298,7 +298,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
end
end
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
return (simTrajectoryReward=simTrajectoryReward,
terminalstate=terminalstate)
end
""" Make new state