This commit is contained in:
narawat lamaiin
2024-08-18 19:15:55 +07:00
parent 32dec4b47c
commit 997cc904dd

View File

@@ -2,7 +2,7 @@ module mcts
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode, export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
expand, simulate, makeNewState expand, simulate, makeNewState
using Base.Threads
using GeneralUtils using GeneralUtils
using ..type using ..type
@@ -196,16 +196,22 @@ end
# Return # Return
- None - None
# TODO
- [WORKING] implement multithreads
# Signature # Signature
""" """
function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
totalsample::Integer=3) totalsample::Integer=3)
results = Any[]
@sync for i in 1:totalsample
@spawn begin
results[i] = transition(deepcopy(node.state), deepcopy(transitionargs))
end
println("--> sampling $i")
end
nthSample = 0 for result in results
while true
nthSample += 1
if nthSample <= totalsample
result = transition(node.state, transitionargs)
newNodeKey::AbstractString = result[:newNodeKey] newNodeKey::AbstractString = result[:newNodeKey]
newstate::AbstractDict = result[:newstate] newstate::AbstractDict = result[:newstate]
progressvalue::Integer = result[:progressvalue] progressvalue::Integer = result[:progressvalue]
@@ -222,11 +228,37 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
newstate[:isterminal], node, Dict{String, MCTSNode}()) newstate[:isterminal], node, Dict{String, MCTSNode}())
end end
else
break
end
end end
end end
# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
# totalsample::Integer=3)
# nthSample = 0
# while true
# nthSample += 1
# if nthSample <= totalsample
# result = transition(node.state, transitionargs)
# newNodeKey::AbstractString = result[:newNodeKey]
# newstate::AbstractDict = result[:newstate]
# progressvalue::Integer = result[:progressvalue]
# """
# [] newNodeKey ∉ keys(node.children).
# New state may have semantic vector close enought to
# one of existing child state. Which can be assume that they are the same state
# semantically-wise i.e. De javu. This could be used to recall lessons for this
# similar situation to improve decisionMaker and evaluator.
# """
# if newNodeKey ∉ keys(node.children)
# node.children[newNodeKey] =
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
# newstate[:isterminal], node, Dict{String, MCTSNode}())
# end
# else
# break
# end
# end
# end
""" Simulate interactions between agent and environment """ Simulate interactions between agent and environment