update
This commit is contained in:
50
src/mcts.jl
50
src/mcts.jl
@@ -2,7 +2,7 @@ module mcts
|
|||||||
|
|
||||||
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode,
|
||||||
expand, simulate, makeNewState
|
expand, simulate, makeNewState
|
||||||
|
using Base.Threads
|
||||||
using GeneralUtils
|
using GeneralUtils
|
||||||
|
|
||||||
using ..type
|
using ..type
|
||||||
@@ -196,16 +196,22 @@ end
|
|||||||
# Return
|
# Return
|
||||||
- None
|
- None
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [WORKING] implement multithreads
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
function expand(node::MCTSNode, transition::Function, transitionargs::NamedTuple;
|
||||||
totalsample::Integer=3)
|
totalsample::Integer=3)
|
||||||
|
results = Any[]
|
||||||
|
@sync for i in 1:totalsample
|
||||||
|
@spawn begin
|
||||||
|
results[i] = transition(deepcopy(node.state), deepcopy(transitionargs))
|
||||||
|
end
|
||||||
|
println("--> sampling $i")
|
||||||
|
end
|
||||||
|
|
||||||
nthSample = 0
|
for result in results
|
||||||
while true
|
|
||||||
nthSample += 1
|
|
||||||
if nthSample <= totalsample
|
|
||||||
result = transition(node.state, transitionargs)
|
|
||||||
newNodeKey::AbstractString = result[:newNodeKey]
|
newNodeKey::AbstractString = result[:newNodeKey]
|
||||||
newstate::AbstractDict = result[:newstate]
|
newstate::AbstractDict = result[:newstate]
|
||||||
progressvalue::Integer = result[:progressvalue]
|
progressvalue::Integer = result[:progressvalue]
|
||||||
@@ -222,11 +228,37 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
|||||||
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
else
|
|
||||||
break
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
# function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple;
|
||||||
|
# totalsample::Integer=3)
|
||||||
|
|
||||||
|
# nthSample = 0
|
||||||
|
# while true
|
||||||
|
# nthSample += 1
|
||||||
|
# if nthSample <= totalsample
|
||||||
|
# result = transition(node.state, transitionargs)
|
||||||
|
# newNodeKey::AbstractString = result[:newNodeKey]
|
||||||
|
# newstate::AbstractDict = result[:newstate]
|
||||||
|
# progressvalue::Integer = result[:progressvalue]
|
||||||
|
|
||||||
|
# """
|
||||||
|
# [] newNodeKey ∉ keys(node.children).
|
||||||
|
# New state may have a semantic vector close enough to
|
||||||
|
# one of the existing child states, in which case they can be assumed to be the same state
|
||||||
|
# semantically, i.e. déjà vu. This could be used to recall lessons for this
|
||||||
|
# similar situation to improve decisionMaker and evaluator.
|
||||||
|
# """
|
||||||
|
# if newNodeKey ∉ keys(node.children)
|
||||||
|
# node.children[newNodeKey] =
|
||||||
|
# MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward],
|
||||||
|
# newstate[:isterminal], node, Dict{String, MCTSNode}())
|
||||||
|
# end
|
||||||
|
# else
|
||||||
|
# break
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
""" Simulate interactions between agent and environment
|
""" Simulate interactions between agent and environment
|
||||||
|
|||||||
Reference in New Issue
Block a user