update
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
module LLMMCTS
|
module LLMMCTS
|
||||||
|
|
||||||
# export agent
|
export MCTSNode
|
||||||
|
|
||||||
|
|
||||||
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
""" Order by dependencies of each file. The 1st included file must not depend on any other
|
||||||
|
|||||||
@@ -69,6 +69,9 @@ function runMCTS(
|
|||||||
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String,MCTSNode}(),
|
||||||
Dict{Symbol,Any}())
|
Dict{Symbol,Any}())
|
||||||
|
|
||||||
|
# [WORKING] storage for holding all high reward terminal nodes
|
||||||
|
highStateValueNode = Channel{Any}(100)
|
||||||
|
|
||||||
for nth in 1:maxiterations
|
for nth in 1:maxiterations
|
||||||
node = root
|
node = root
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
@@ -78,6 +81,10 @@ function runMCTS(
|
|||||||
end
|
end
|
||||||
|
|
||||||
if node.isterminal
|
if node.isterminal
|
||||||
|
if node.state[:reward] >= 8
|
||||||
|
put!(highrewardNode, deepcopy(node.state))
|
||||||
|
end
|
||||||
|
|
||||||
# MCTS arrive at the leaf node that is also a terminal state,
|
# MCTS arrive at the leaf node that is also a terminal state,
|
||||||
# do nothing then go directly to backpropagation. It means the end of this iteration
|
# do nothing then go directly to backpropagation. It means the end of this iteration
|
||||||
backpropagate(node, node.reward)
|
backpropagate(node, node.reward)
|
||||||
@@ -91,7 +98,9 @@ function runMCTS(
|
|||||||
maxSimulationDepth=maxSimulationDepth,
|
maxSimulationDepth=maxSimulationDepth,
|
||||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||||
saveSimulatedNode=saveSimulatedNode,
|
saveSimulatedNode=saveSimulatedNode,
|
||||||
multithread=multithread)
|
multithread=multithread,
|
||||||
|
highStateValueNode=highStateValueNode,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
for (leafNodeKey, leafNode) in node.children
|
for (leafNodeKey, leafNode) in node.children
|
||||||
@@ -99,7 +108,8 @@ function runMCTS(
|
|||||||
maxSimulationDepth=maxSimulationDepth,
|
maxSimulationDepth=maxSimulationDepth,
|
||||||
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
horizontalSampleSimulationPhase=horizontalSampleSimulationPhase,
|
||||||
saveSimulatedNode=saveSimulatedNode,
|
saveSimulatedNode=saveSimulatedNode,
|
||||||
multithread=multithread)
|
multithread=multithread,
|
||||||
|
highStateValueNode=highStateValueNode)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -114,6 +124,8 @@ function runMCTS(
|
|||||||
bestNextState = selectBestNextNode(root)
|
bestNextState = selectBestNextNode(root)
|
||||||
besttrajectory = selectBestTrajectoryNode(root)
|
besttrajectory = selectBestTrajectoryNode(root)
|
||||||
|
|
||||||
|
#[WORKING] compare all high value answer then select the best one
|
||||||
|
|
||||||
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
return (root=root, bestNextState=bestNextState.state, bestFinalState=besttrajectory.state)
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -143,11 +155,21 @@ end
|
|||||||
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
function simulateThenBackpropagate(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||||
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
maxSimulationDepth::Integer=3, horizontalSampleSimulationPhase::Integer=3,
|
||||||
saveSimulatedNode::Bool=false,
|
saveSimulatedNode::Bool=false,
|
||||||
multithread=false)
|
multithread=false,
|
||||||
simTrajectoryReward, terminalstate = simulate(node, transition, transitionargs;
|
highStateValueNode=Union{Nothing,Any}=nothing)
|
||||||
|
simTrajectoryReward, terminalstate =
|
||||||
|
simulate(node, transition, transitionargs;
|
||||||
maxSimulationDepth=maxSimulationDepth,
|
maxSimulationDepth=maxSimulationDepth,
|
||||||
horizontalSample=horizontalSampleSimulationPhase,
|
horizontalSample=horizontalSampleSimulationPhase,
|
||||||
multithread=multithread)
|
multithread=multithread)
|
||||||
|
# if a node has state value >= 8, store it in highStateValueNode
|
||||||
|
if highStateValueNode !== nothing &&
|
||||||
|
terminalstate !== nothing &&
|
||||||
|
terminalstate[:reward] >= 8
|
||||||
|
|
||||||
|
put!(highStateValueNode, deepcopy(terminalstate))
|
||||||
|
end
|
||||||
|
|
||||||
backpropagate(node, simTrajectoryReward)
|
backpropagate(node, simTrajectoryReward)
|
||||||
|
|
||||||
# check if the user wants to keep the simulated node
|
# check if the user wants to keep the simulated node
|
||||||
|
|||||||
@@ -298,7 +298,8 @@ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTup
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return (simTrajectoryReward=simTrajectoryReward, terminalstate=terminalstate)
|
return (simTrajectoryReward=simTrajectoryReward,
|
||||||
|
terminalstate=terminalstate)
|
||||||
end
|
end
|
||||||
|
|
||||||
""" Make new state
|
""" Make new state
|
||||||
|
|||||||
Reference in New Issue
Block a user