using for loop to install conda package

This commit is contained in:
ton
2023-04-16 11:03:27 +07:00
parent 49da9f29c1
commit 0c2b34d6f8
12168 changed files with 2656238 additions and 1 deletions

View File

@@ -0,0 +1,6 @@
from networkx.utils.misc import *
from networkx.utils.decorators import *
from networkx.utils.random_sequence import *
from networkx.utils.union_find import *
from networkx.utils.rcm import *
from networkx.utils.heaps import *

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,340 @@
"""
Min-heaps.
"""
from heapq import heappop, heappush
from itertools import count
import networkx as nx
__all__ = ["MinHeap", "PairingHeap", "BinaryHeap"]
class MinHeap:
    """Base class for min-heaps.

    A MinHeap stores a collection of key-value pairs ordered by their values.
    It supports querying the minimum pair, inserting a new pair, decreasing the
    value in an existing pair and deleting the minimum pair.
    """
    class _Item:
        """Used by subclasses to represent a key-value pair."""
        __slots__ = ("key", "value")
        def __init__(self, key, value):
            self.key = key
            self.value = value
        def __repr__(self):
            return repr((self.key, self.value))
    def __init__(self):
        """Initialize a new min-heap."""
        # Maps each key to its stored value (or _Item); gives O(1)
        # membership tests and lookups. Subclasses keep this in sync
        # with their internal heap structure.
        self._dict = {}
    def min(self):
        """Query the minimum key-value pair.

        Returns
        -------
        key, value : tuple
            The key-value pair with the minimum value in the heap.

        Raises
        ------
        NetworkXError
            If the heap is empty.
        """
        raise NotImplementedError
    def pop(self):
        """Delete the minimum pair in the heap.

        Returns
        -------
        key, value : tuple
            The key-value pair with the minimum value in the heap.

        Raises
        ------
        NetworkXError
            If the heap is empty.
        """
        raise NotImplementedError
    def get(self, key, default=None):
        """Returns the value associated with a key.

        Parameters
        ----------
        key : hashable object
            The key to be looked up.
        default : object
            Default value to return if the key is not present in the heap.
            Default value: None.

        Returns
        -------
        value : object.
            The value associated with the key.
        """
        raise NotImplementedError
    def insert(self, key, value, allow_increase=False):
        """Insert a new key-value pair or modify the value in an existing
        pair.

        Parameters
        ----------
        key : hashable object
            The key.
        value : object comparable with existing values.
            The value.
        allow_increase : bool
            Whether the value is allowed to increase. If False, attempts to
            increase an existing value have no effect. Default value: False.

        Returns
        -------
        decreased : bool
            True if a pair is inserted or the existing value is decreased.
        """
        raise NotImplementedError
    def __nonzero__(self):
        """Returns True if the heap is nonempty (Python 2 bool protocol)."""
        return bool(self._dict)
    def __bool__(self):
        """Returns True if the heap is nonempty."""
        return bool(self._dict)
    def __len__(self):
        """Returns the number of key-value pairs in the heap."""
        return len(self._dict)
    def __contains__(self, key):
        """Returns whether a key exists in the heap.

        Parameters
        ----------
        key : any hashable object.
            The key to be looked up.
        """
        return key in self._dict
class PairingHeap(MinHeap):
    """A pairing heap.

    Supports O(1) min/insert and amortized O(log n) pop. The heap is a
    single tree of _Node objects rooted at ``self._root``; ``self._dict``
    maps each key to its node for O(1) lookup.
    """
    class _Node(MinHeap._Item):
        """A node in a pairing heap.

        A tree in a pairing heap is stored using the left-child, right-sibling
        representation.
        """
        __slots__ = ("left", "next", "prev", "parent")
        def __init__(self, key, value):
            super().__init__(key, value)
            # The leftmost child.
            self.left = None
            # The next sibling.
            self.next = None
            # The previous sibling.
            self.prev = None
            # The parent.
            self.parent = None
    def __init__(self):
        """Initialize a pairing heap."""
        super().__init__()
        # Root of the tree; holds the minimum pair, or None when empty.
        self._root = None
    def min(self):
        """Return (key, value) with the minimum value without removing it."""
        if self._root is None:
            raise nx.NetworkXError("heap is empty.")
        return (self._root.key, self._root.value)
    def pop(self):
        """Remove and return the (key, value) pair with the minimum value."""
        if self._root is None:
            raise nx.NetworkXError("heap is empty.")
        min_node = self._root
        # Detach the root's children and merge them into the new tree.
        self._root = self._merge_children(self._root)
        del self._dict[min_node.key]
        return (min_node.key, min_node.value)
    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        node = self._dict.get(key)
        return node.value if node is not None else default
    def insert(self, key, value, allow_increase=False):
        """Insert ``key`` with ``value`` or decrease (optionally increase)
        the value of an existing key. Returns True on insert/decrease.
        """
        node = self._dict.get(key)
        root = self._root
        if node is not None:
            if value < node.value:
                node.value = value
                # Decrease-key: cut the node out of the tree and re-link
                # it with the root unless the heap order still holds.
                if node is not root and value < node.parent.value:
                    self._cut(node)
                    self._root = self._link(root, node)
                return True
            elif allow_increase and value > node.value:
                node.value = value
                child = self._merge_children(node)
                # Nonstandard step: Link the merged subtree with the root. See
                # below for the standard step.
                if child is not None:
                    self._root = self._link(self._root, child)
                # Standard step: Perform a decrease followed by a pop as if the
                # value were the smallest in the heap. Then insert the new
                # value into the heap.
                # if node is not root:
                #     self._cut(node)
                #     if child is not None:
                #         root = self._link(root, child)
                #     self._root = self._link(root, node)
                # else:
                #     self._root = (self._link(node, child)
                #                   if child is not None else node)
            return False
        else:
            # Insert a new key.
            node = self._Node(key, value)
            self._dict[key] = node
            self._root = self._link(root, node) if root is not None else node
            return True
    def _link(self, root, other):
        """Link two nodes, making the one with the smaller value the parent of
        the other.
        """
        if other.value < root.value:
            root, other = other, root
        # Prepend `other` to the winner's child list.
        next = root.left
        other.next = next
        if next is not None:
            next.prev = other
        other.prev = None
        root.left = other
        other.parent = root
        return root
    def _merge_children(self, root):
        """Merge the subtrees of the root using the standard two-pass method.

        The resulting subtree is detached from the root.
        """
        node = root.left
        root.left = None
        if node is not None:
            link = self._link
            # Pass 1: Merge pairs of consecutive subtrees from left to right.
            # At the end of the pass, only the prev pointers of the resulting
            # subtrees have meaningful values. The other pointers will be fixed
            # in pass 2.
            prev = None
            while True:
                next = node.next
                if next is None:
                    node.prev = prev
                    break
                next_next = next.next
                node = link(node, next)
                node.prev = prev
                prev = node
                if next_next is None:
                    break
                node = next_next
            # Pass 2: Successively merge the subtrees produced by pass 1 from
            # right to left with the rightmost one.
            prev = node.prev
            while prev is not None:
                prev_prev = prev.prev
                node = link(prev, node)
                prev = prev_prev
            # Now node can become the new root. It has no parent or siblings.
            node.prev = None
            node.next = None
            node.parent = None
        return node
    def _cut(self, node):
        """Cut a node from its parent."""
        prev = node.prev
        next = node.next
        # Unlink from the sibling list; the node may be the leftmost child.
        if prev is not None:
            prev.next = next
        else:
            node.parent.left = next
        node.prev = None
        if next is not None:
            next.prev = prev
        node.next = None
        node.parent = None
class BinaryHeap(MinHeap):
    """A binary heap.

    Uses lazy deletion: decreasing a key's value pushes a new entry instead
    of locating the old one, so the heap may contain stale (value, key)
    entries. ``self._dict`` always holds the authoritative current value;
    stale heap entries are skipped when the minimum is queried or popped.
    """
    def __init__(self):
        """Initialize a binary heap."""
        super().__init__()
        # Heap of (value, tiebreaker, key) triples; may contain stale entries.
        self._heap = []
        # Monotonic counter used as a tiebreaker so keys with equal values
        # never need to be compared to each other.
        self._count = count()
    def min(self):
        # Local alias (shadows the builtin `dict` deliberately, for speed).
        dict = self._dict
        if not dict:
            raise nx.NetworkXError("heap is empty")
        heap = self._heap
        pop = heappop
        # Repeatedly remove stale key-value pairs until an up-to-date one is
        # found; only the up-to-date minimum is left on the heap (not popped).
        while True:
            value, _, key = heap[0]
            if key in dict and value == dict[key]:
                break
            pop(heap)
        return (key, value)
    def pop(self):
        # Local alias (shadows the builtin `dict` deliberately, for speed).
        dict = self._dict
        if not dict:
            raise nx.NetworkXError("heap is empty")
        heap = self._heap
        pop = heappop
        # Repeatedly remove stale key-value pairs until an up-to-date one is
        # found; unlike min(), the matching entry itself is also popped.
        while True:
            value, _, key = heap[0]
            pop(heap)
            if key in dict and value == dict[key]:
                break
        del dict[key]
        return (key, value)
    def get(self, key, default=None):
        return self._dict.get(key, default)
    def insert(self, key, value, allow_increase=False):
        dict = self._dict
        if key in dict:
            old_value = dict[key]
            if value < old_value or (allow_increase and value > old_value):
                # Since there is no way to efficiently obtain the location of a
                # key-value pair in the heap, insert a new pair even if ones
                # with the same key may already be present. Deem the old ones
                # as stale and skip them when the minimum pair is queried.
                dict[key] = value
                heappush(self._heap, (value, next(self._count), key))
            return value < old_value
        else:
            dict[key] = value
            heappush(self._heap, (value, next(self._count), key))
            return True

View File

@@ -0,0 +1,298 @@
"""Priority queue class with updatable priorities.
"""
import heapq
__all__ = ["MappedQueue"]
class _HeapElement:
"""This proxy class separates the heap element from its priority.
The idea is that using a 2-tuple (priority, element) works
for sorting, but not for dict lookup because priorities are
often floating point values so round-off can mess up equality.
So, we need inequalities to look at the priority (for sorting)
and equality (and hash) to look at the element to enable
updates to the priority.
Unfortunately, this class can be tricky to work with if you forget that
`__lt__` compares the priority while `__eq__` compares the element.
In `greedy_modularity_communities()` the following code is
used to check that two _HeapElements differ in either element or priority:
if d_oldmax != row_max or d_oldmax.priority != row_max.priority:
If the priorities are the same, this implementation uses the element
as a tiebreaker. This provides compatibility with older systems that
use tuples to combine priority and elements.
"""
__slots__ = ["priority", "element", "_hash"]
def __init__(self, priority, element):
self.priority = priority
self.element = element
self._hash = hash(element)
def __lt__(self, other):
try:
other_priority = other.priority
except AttributeError:
return self.priority < other
# assume comparing to another _HeapElement
if self.priority == other_priority:
try:
return self.element < other.element
except TypeError as err:
raise TypeError(
"Consider using a tuple, with a priority value that can be compared."
)
return self.priority < other_priority
def __gt__(self, other):
try:
other_priority = other.priority
except AttributeError:
return self.priority > other
# assume comparing to another _HeapElement
if self.priority == other_priority:
try:
return self.element > other.element
except TypeError as err:
raise TypeError(
"Consider using a tuple, with a priority value that can be compared."
)
return self.priority > other_priority
def __eq__(self, other):
try:
return self.element == other.element
except AttributeError:
return self.element == other
def __hash__(self):
return self._hash
def __getitem__(self, indx):
return self.priority if indx == 0 else self.element[indx - 1]
def __iter__(self):
yield self.priority
try:
yield from self.element
except TypeError:
yield self.element
def __repr__(self):
return f"_HeapElement({self.priority}, {self.element})"
class MappedQueue:
    """The MappedQueue class implements a min-heap with removal and update-priority.

    The min heap uses heapq as well as custom written _siftup and _siftdown
    methods to allow the heap positions to be tracked by an additional dict
    keyed by element to position. The smallest element can be popped in O(1) time,
    new elements can be pushed in O(log n) time, and any element can be removed
    or updated in O(log n) time. The queue cannot contain duplicate elements
    and an attempt to push an element already in the queue will have no effect.

    MappedQueue complements the heapq package from the python standard
    library. While MappedQueue is designed for maximum compatibility with
    heapq, it adds element removal, lookup, and priority update.

    Parameters
    ----------
    data : dict or iterable

    Examples
    --------
    A `MappedQueue` can be created empty, or optionally, given a dictionary
    of initial elements and priorities. The methods `push`, `pop`,
    `remove`, and `update` operate on the queue.

    >>> colors_nm = {'red':665, 'blue': 470, 'green': 550}
    >>> q = MappedQueue(colors_nm)
    >>> q.remove('red')
    >>> q.update('green', 'violet', 400)
    >>> q.push('indigo', 425)
    True
    >>> [q.pop().element for i in range(len(q.heap))]
    ['violet', 'indigo', 'blue']

    A `MappedQueue` can also be initialized with a list or other iterable. The priority is assumed
    to be the sort order of the items in the list.

    >>> q = MappedQueue([916, 50, 4609, 493, 237])
    >>> q.remove(493)
    >>> q.update(237, 1117)
    >>> [q.pop() for i in range(len(q.heap))]
    [50, 916, 1117, 4609]

    An exception is raised if the elements are not comparable.

    >>> q = MappedQueue([100, 'a'])
    Traceback (most recent call last):
    ...
    TypeError: '<' not supported between instances of 'int' and 'str'

    To avoid the exception, use a dictionary to assign priorities to the elements.

    >>> q = MappedQueue({100: 0, 'a': 1 })

    References
    ----------
    .. [1] Cormen, T. H., Leiserson, C. E., Rivest, R. L., & Stein, C. (2001).
       Introduction to algorithms second edition.
    .. [2] Knuth, D. E. (1997). The art of computer programming (Vol. 3).
       Pearson Education.
    """
    def __init__(self, data=None):
        """Priority queue class with updatable priorities."""
        if data is None:
            self.heap = []
        elif isinstance(data, dict):
            # Wrap (priority, element) pairs so equality/hash follow the
            # element while ordering follows the priority.
            self.heap = [_HeapElement(v, k) for k, v in data.items()]
        else:
            self.heap = list(data)
        self.position = {}
        self._heapify()
    def _heapify(self):
        """Restore heap invariant and recalculate map."""
        heapq.heapify(self.heap)
        # Rebuild element -> heap-index map to match the heapified order.
        self.position = {elt: pos for pos, elt in enumerate(self.heap)}
        if len(self.heap) != len(self.position):
            raise AssertionError("Heap contains duplicate elements")
    def __len__(self):
        return len(self.heap)
    def push(self, elt, priority=None):
        """Add an element to the queue."""
        if priority is not None:
            elt = _HeapElement(priority, elt)
        # If element is already in queue, do nothing
        if elt in self.position:
            return False
        # Add element to heap and dict
        pos = len(self.heap)
        self.heap.append(elt)
        self.position[elt] = pos
        # Restore invariant by sifting down
        self._siftdown(0, pos)
        return True
    def pop(self):
        """Remove and return the smallest element in the queue."""
        # Remove smallest element
        elt = self.heap[0]
        del self.position[elt]
        # If elt is last item, remove and return
        if len(self.heap) == 1:
            self.heap.pop()
            return elt
        # Replace root with last element
        last = self.heap.pop()
        self.heap[0] = last
        self.position[last] = 0
        # Restore invariant by sifting up
        self._siftup(0)
        # Return smallest element
        return elt
    def update(self, elt, new, priority=None):
        """Replace an element in the queue with a new one."""
        if priority is not None:
            new = _HeapElement(priority, new)
        # Replace
        pos = self.position[elt]
        self.heap[pos] = new
        del self.position[elt]
        self.position[new] = pos
        # Restore invariant by sifting up.
        # NOTE: _siftup handles both directions — it sinks the slot to a
        # leaf and then bubbles the new item back up, so a single call
        # suffices whether the new priority is larger or smaller.
        self._siftup(pos)
    def remove(self, elt):
        """Remove an element from the queue."""
        # Find and remove element
        try:
            pos = self.position[elt]
            del self.position[elt]
        except KeyError:
            # Not in queue
            raise
        # If elt is last item, remove and return
        if pos == len(self.heap) - 1:
            self.heap.pop()
            return
        # Replace elt with last element
        last = self.heap.pop()
        self.heap[pos] = last
        self.position[last] = pos
        # Restore invariant by sifting up
        self._siftup(pos)
    def _siftup(self, pos):
        """Move smaller child up until hitting a leaf.

        Built to mimic code for heapq._siftup
        only updating position dict too.
        """
        heap, position = self.heap, self.position
        end_pos = len(heap)
        startpos = pos
        newitem = heap[pos]
        # Shift up the smaller child until hitting a leaf
        child_pos = (pos << 1) + 1  # start with leftmost child position
        while child_pos < end_pos:
            # Set child_pos to index of smaller child.
            child = heap[child_pos]
            right_pos = child_pos + 1
            if right_pos < end_pos:
                right = heap[right_pos]
                # On ties (left == right) the right child is chosen,
                # matching heapq._siftup's `not <` comparison.
                if not child < right:
                    child = right
                    child_pos = right_pos
            # Move the smaller child up.
            heap[pos] = child
            position[child] = pos
            pos = child_pos
            child_pos = (pos << 1) + 1
        # pos is a leaf position. Put newitem there, and bubble it up
        # to its final resting place (by sifting its parents down).
        while pos > 0:
            parent_pos = (pos - 1) >> 1
            parent = heap[parent_pos]
            if not newitem < parent:
                break
            heap[pos] = parent
            position[parent] = pos
            pos = parent_pos
        heap[pos] = newitem
        position[newitem] = pos
    def _siftdown(self, start_pos, pos):
        """Restore invariant. keep swapping with parent until smaller.

        Built to mimic code for heapq._siftdown
        only updating position dict too.
        """
        heap, position = self.heap, self.position
        newitem = heap[pos]
        # Follow the path to the root, moving parents down until finding a place
        # newitem fits.
        while pos > start_pos:
            parent_pos = (pos - 1) >> 1
            parent = heap[parent_pos]
            if not newitem < parent:
                break
            heap[pos] = parent
            position[parent] = pos
            pos = parent_pos
        heap[pos] = newitem
        position[newitem] = pos

View File

@@ -0,0 +1,491 @@
"""
Miscellaneous Helpers for NetworkX.
These are not imported into the base networkx namespace but
can be accessed, for example, as
>>> import networkx
>>> networkx.utils.make_list_of_ints({1, 2, 3})
[1, 2, 3]
>>> networkx.utils.arbitrary_element({5, 1, 7}) # doctest: +SKIP
1
"""
import sys
import uuid
import warnings
from collections import defaultdict, deque
from collections.abc import Iterable, Iterator, Sized
from itertools import chain, tee
import networkx as nx
__all__ = [
"flatten",
"make_list_of_ints",
"dict_to_numpy_array",
"arbitrary_element",
"pairwise",
"groups",
"create_random_state",
"create_py_random_state",
"PythonRandomInterface",
"nodes_equal",
"edges_equal",
"graphs_equal",
]
# some cookbook stuff
# used in deciding whether something is a bunch of nodes, edges, etc.
# see G.add_nodes and others in Graph Class in networkx/base.py
def flatten(obj, result=None):
    """Return flattened version of (possibly nested) iterable object.

    Atomic items (strings, or anything that is neither Iterable nor Sized)
    are returned unchanged; everything else is flattened into a tuple.
    """
    def _atomic(x):
        # Strings count as atoms even though they are iterable.
        return isinstance(x, str) or not isinstance(x, (Iterable, Sized))

    if _atomic(obj):
        return obj
    accumulator = [] if result is None else result
    for element in obj:
        if _atomic(element):
            accumulator.append(element)
        else:
            flatten(element, accumulator)
    return tuple(accumulator)
def make_list_of_ints(sequence):
    """Return list of ints from sequence of integral numbers.

    All elements of the sequence must satisfy int(element) == element
    or a ValueError is raised. Sequence is iterated through once.

    If sequence is a list, the non-int values are replaced with ints.
    So, no new list is created
    """
    def _coerce(value):
        # Convert to int, rejecting anything that loses information.
        errmsg = f"sequence is not all integers: {value}"
        try:
            coerced = int(value)
        except ValueError:
            raise nx.NetworkXError(errmsg) from None
        if coerced != value:
            raise nx.NetworkXError(errmsg)
        return coerced

    if not isinstance(sequence, list):
        # Non-list input: build and return a fresh list of ints.
        return [_coerce(item) for item in sequence]
    # Original sequence is a list... in-place conversion to ints.
    for position, item in enumerate(sequence):
        if not isinstance(item, int):
            sequence[position] = _coerce(item)
    return sequence
def dict_to_numpy_array(d, mapping=None):
    """Convert a dictionary of dictionaries to a numpy array
    with optional mapping.

    Tries the 2d (dict-of-dicts) conversion first and falls back to the
    1d (dict-of-numbers) conversion when the values are not dict-like.
    """
    try:
        return _dict_to_numpy_array2(d, mapping)
    except (AttributeError, TypeError):
        # AttributeError is when no mapping was provided and v.keys() fails.
        # TypeError is when a mapping was provided and d[k1][k2] fails.
        return _dict_to_numpy_array1(d, mapping)
def _dict_to_numpy_array2(d, mapping=None):
"""Convert a dictionary of dictionaries to a 2d numpy array
with optional mapping.
"""
import numpy as np
if mapping is None:
s = set(d.keys())
for k, v in d.items():
s.update(v.keys())
mapping = dict(zip(s, range(len(s))))
n = len(mapping)
a = np.zeros((n, n))
for k1, i in mapping.items():
for k2, j in mapping.items():
try:
a[i, j] = d[k1][k2]
except KeyError:
pass
return a
def _dict_to_numpy_array1(d, mapping=None):
"""Convert a dictionary of numbers to a 1d numpy array with optional mapping."""
import numpy as np
if mapping is None:
s = set(d.keys())
mapping = dict(zip(s, range(len(s))))
n = len(mapping)
a = np.zeros(n)
for k1, i in mapping.items():
i = mapping[k1]
a[i] = d[k1]
return a
def arbitrary_element(iterable):
    """Returns an arbitrary element of `iterable` without removing it.

    This is most useful for "peeking" at an arbitrary element of a set,
    but can be used for any list, dictionary, etc., as well.

    Parameters
    ----------
    iterable : `abc.collections.Iterable` instance
        Any object that implements ``__iter__``, e.g. set, dict, list, tuple,
        etc.

    Returns
    -------
    The object that results from ``next(iter(iterable))``

    Raises
    ------
    ValueError
        If `iterable` is an iterator (because the current implementation of
        this function would consume an element from the iterator).

    Notes
    -----
    This function does not return a *random* element. If `iterable` is
    ordered, sequential calls will return the same value.
    """
    # Guard clause: consuming from an iterator would be a destructive peek.
    if isinstance(iterable, Iterator):
        raise ValueError("cannot return an arbitrary item from an iterator")
    # Another possible implementation is ``for x in iterable: return x``.
    return next(iter(iterable))
# Recipe from the itertools documentation.
def pairwise(iterable, cyclic=False):
    """Return successive overlapping pairs: s -> (s0, s1), (s1, s2), (s2, s3), ...

    When ``cyclic`` is exactly True, a final pair wrapping back to the
    first element is appended.
    """
    left, right = tee(iterable)
    # Advance the second iterator; remember its first element for the
    # wrap-around pair.
    first = next(right, None)
    if cyclic is True:
        return zip(left, chain(right, (first,)))
    return zip(left, right)
def groups(many_to_one):
    """Converts a many-to-one mapping into a one-to-many mapping.

    `many_to_one` must be a dictionary whose keys and values are all
    :term:`hashable`.

    The return value is a dictionary mapping values from `many_to_one`
    to sets of keys from `many_to_one` that have that value.

    Examples
    --------
    >>> from networkx.utils import groups
    >>> many_to_one = {"a": 1, "b": 1, "c": 2, "d": 3, "e": 3}
    >>> groups(many_to_one)  # doctest: +SKIP
    {1: {'a', 'b'}, 2: {'c'}, 3: {'e', 'd'}}
    """
    inverted = {}
    for source, target in many_to_one.items():
        inverted.setdefault(target, set()).add(source)
    return inverted
def create_random_state(random_state=None):
    """Returns a numpy.random.RandomState or numpy.random.Generator instance
    depending on input.

    Parameters
    ----------
    random_state : int or NumPy RandomState or Generator instance, optional (default=None)
        If int, return a numpy.random.RandomState instance set with seed=int.
        if `numpy.random.RandomState` instance, return it.
        if `numpy.random.Generator` instance, return it.
        if None or numpy, return the global random number generator used
        by numpy.random.
    """
    import numpy as np

    if random_state is None or random_state is np.random:
        # Global numpy RNG singleton.
        return np.random.mtrand._rand
    # RandomState and Generator are disjoint types, so a combined check
    # is equivalent to testing each separately.
    if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
        return random_state
    if isinstance(random_state, int):
        return np.random.RandomState(random_state)
    raise ValueError(
        f"{random_state} cannot be used to create a numpy.random.RandomState or\n"
        "numpy.random.Generator instance"
    )
class PythonRandomInterface:
    """Adapter exposing a ``random.Random``-like API backed by a numpy RNG.

    Wraps a ``numpy.random.RandomState`` or ``numpy.random.Generator`` so
    that code written against the stdlib ``random`` module's method names
    (``random``, ``uniform``, ``randrange``, ``choice``, ...) can use a
    numpy generator underneath.
    """
    def __init__(self, rng=None):
        try:
            import numpy as np
        except ImportError:
            msg = "numpy not found, only random.random available."
            warnings.warn(msg, ImportWarning)
        # NOTE(review): if numpy is unavailable and rng is None, `np` is
        # unbound here and this raises NameError — confirm intended fallback.
        if rng is None:
            self._rng = np.random.mtrand._rand
        else:
            self._rng = rng
    def random(self):
        return self._rng.random()
    def uniform(self, a, b):
        return a + (b - a) * self._rng.random()
    def randrange(self, a, b=None):
        import numpy as np
        # Generator and RandomState spell the integer-sampling method
        # differently (`integers` vs `randint`); both exclude the upper bound.
        if isinstance(self._rng, np.random.Generator):
            return self._rng.integers(a, b)
        return self._rng.randint(a, b)
    # NOTE: the numpy implementations of `choice` don't support strings, so
    # this cannot be replaced with self._rng.choice
    def choice(self, seq):
        import numpy as np
        if isinstance(self._rng, np.random.Generator):
            idx = self._rng.integers(0, len(seq))
        else:
            idx = self._rng.randint(0, len(seq))
        return seq[idx]
    def gauss(self, mu, sigma):
        return self._rng.normal(mu, sigma)
    def shuffle(self, seq):
        # In-place shuffle, mirroring random.shuffle.
        return self._rng.shuffle(seq)
    #    Some methods don't match API for numpy RandomState.
    #    Commented out versions are not used by NetworkX
    def sample(self, seq, k):
        # Presumably returns a numpy array rather than the list that
        # random.sample returns — verify against callers.
        return self._rng.choice(list(seq), size=(k,), replace=False)
    def randint(self, a, b):
        import numpy as np
        # random.randint is inclusive of b, numpy's upper bound is
        # exclusive — hence b + 1.
        if isinstance(self._rng, np.random.Generator):
            return self._rng.integers(a, b + 1)
        return self._rng.randint(a, b + 1)
    #    exponential as expovariate with 1/argument,
    def expovariate(self, scale):
        return self._rng.exponential(1 / scale)
    #    pareto as paretovariate with 1/argument,
    def paretovariate(self, shape):
        return self._rng.pareto(shape)
    #    weibull as weibullvariate multiplied by beta,
    # def weibullvariate(self, alpha, beta):
    #     return self._rng.weibull(alpha) * beta
    #
    # def triangular(self, low, high, mode):
    #     return self._rng.triangular(low, mode, high)
    #
    # def choices(self, seq, weights=None, cum_weights=None, k=1):
    #     return self._rng.choice(seq
def create_py_random_state(random_state=None):
    """Returns a random.Random instance depending on input.

    Parameters
    ----------
    random_state : int or random number generator or None (default=None)
        If int, return a random.Random instance set with seed=int.
        if random.Random instance, return it.
        if None or the `random` package, return the global random number
        generator used by `random`.
        if np.random package, return the global numpy random number
        generator wrapped in a PythonRandomInterface class.
        if np.random.RandomState or np.random.Generator instance, return it
        wrapped in PythonRandomInterface
        if a PythonRandomInterface instance, return it
    """
    import random
    # numpy-based inputs are handled first; if numpy is absent, fall
    # through to the pure-stdlib cases below.
    try:
        import numpy as np
        if random_state is np.random:
            return PythonRandomInterface(np.random.mtrand._rand)
        if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
            return PythonRandomInterface(random_state)
        if isinstance(random_state, PythonRandomInterface):
            return random_state
    except ImportError:
        pass
    if random_state is None or random_state is random:
        # Global stdlib RNG singleton.
        return random._inst
    if isinstance(random_state, random.Random):
        return random_state
    if isinstance(random_state, int):
        return random.Random(random_state)
    msg = f"{random_state} cannot be used to generate a random.Random instance"
    raise ValueError(msg)
def nodes_equal(nodes1, nodes2):
    """Check if nodes are equal.

    Equality here means equal as Python objects.
    Node data must match if included.
    The order of nodes is not relevant.

    Parameters
    ----------
    nodes1, nodes2 : iterables of nodes, or (node, datadict) tuples

    Returns
    -------
    bool
        True if nodes are equal, False otherwise.
    """
    seq1, seq2 = list(nodes1), list(nodes2)
    # (node, data) tuples convert directly to dicts; bare nodes do not,
    # so fall back to treating each node as a data-less key.
    try:
        lookup1 = dict(seq1)
        lookup2 = dict(seq2)
    except (ValueError, TypeError):
        lookup1 = dict.fromkeys(seq1)
        lookup2 = dict.fromkeys(seq2)
    return lookup1 == lookup2
def edges_equal(edges1, edges2):
    """Check if edges are equal.

    Equality here means equal as Python objects.
    Edge data must match if included.
    The order of the edges is not relevant.

    Parameters
    ----------
    edges1, edges2 : iterables of with u, v nodes as
        edge tuples (u, v), or
        edge tuples with data dicts (u, v, d), or
        edge tuples with keys and data dicts (u, v, k, d)

    Returns
    -------
    bool
        True if edges are equal, False otherwise.
    """
    from collections import defaultdict

    d1 = defaultdict(dict)
    d2 = defaultdict(dict)
    # Count edges with enumerate(..., 1) so c1/c2 are true edge counts.
    # (The original used enumerate(...) with c1 = 0 initially, which left
    # the *last index* in c1 — making an empty iterable and a one-edge
    # iterable both report 0 and compare as equal.)
    c1 = 0
    for c1, e in enumerate(edges1, 1):
        u, v = e[0], e[1]
        # Everything after (u, v) — key and/or data dict — is the payload.
        data = [e[2:]]
        if v in d1[u]:
            data = d1[u][v] + data
        d1[u][v] = data
        d1[v][u] = data
    c2 = 0
    for c2, e in enumerate(edges2, 1):
        u, v = e[0], e[1]
        data = [e[2:]]
        if v in d2[u]:
            data = d2[u][v] + data
        d2[u][v] = data
        d2[v][u] = data
    if c1 != c2:
        return False
    # can check one direction because lengths are the same.
    for n, nbrdict in d1.items():
        for nbr, datalist in nbrdict.items():
            if n not in d2:
                return False
            if nbr not in d2[n]:
                return False
            d2datalist = d2[n][nbr]
            # Multisets of payloads per node pair must match.
            for data in datalist:
                if datalist.count(data) != d2datalist.count(data):
                    return False
    return True
def graphs_equal(graph1, graph2):
    """Check if graphs are equal.

    Equality here means equal as Python objects (not isomorphism).
    Node, edge and graph data must match.

    Parameters
    ----------
    graph1, graph2 : graph

    Returns
    -------
    bool
        True if graphs are equal, False otherwise.
    """
    # Compare adjacency, node views and graph-level attributes in turn,
    # short-circuiting on the first mismatch.
    if graph1.adj != graph2.adj:
        return False
    if graph1.nodes != graph2.nodes:
        return False
    return graph1.graph == graph2.graph

View File

@@ -0,0 +1,164 @@
"""
Utilities for generating random numbers, random sequences, and
random selections.
"""
import networkx as nx
from networkx.utils import py_random_state
__all__ = [
"powerlaw_sequence",
"zipf_rv",
"cumulative_distribution",
"discrete_sequence",
"random_weighted_sample",
"weighted_choice",
]
# The same helpers for choosing random sequences from distributions
# uses Python's random module
# https://docs.python.org/3/library/random.html
@py_random_state(2)
def powerlaw_sequence(n, exponent=2.0, seed=None):
    """
    Return sample sequence of length n from a power law distribution.

    The ``seed`` argument is normalized by the ``py_random_state``
    decorator into a random-number generator exposing ``paretovariate``.
    """
    # paretovariate(a) samples a Pareto distribution with shape a; using
    # a = exponent - 1 gives a density proportional to x**-exponent.
    return [seed.paretovariate(exponent - 1) for i in range(n)]
@py_random_state(2)
def zipf_rv(alpha, xmin=1, seed=None):
    r"""Returns a random value chosen from the Zipf distribution.

    The return value is an integer drawn from the probability distribution

    .. math::

        p(x)=\frac{x^{-\alpha}}{\zeta(\alpha, x_{\min})},

    where $\zeta(\alpha, x_{\min})$ is the Hurwitz zeta function.

    Parameters
    ----------
    alpha : float
        Exponent value of the distribution
    xmin : int
        Minimum value
    seed : integer, random_state, or None (default)
        Indicator of random number generation state.
        See :ref:`Randomness<randomness>`.

    Returns
    -------
    x : int
        Random value from Zipf distribution

    Raises
    ------
    ValueError:
        If xmin < 1 or
        If alpha <= 1

    Notes
    -----
    The rejection algorithm generates random values for the power-law
    distribution in uniformly bounded expected time dependent on
    parameters. See [1]_ for details on its operation.

    Examples
    --------
    >>> nx.utils.zipf_rv(alpha=2, xmin=3, seed=42)
    8

    References
    ----------
    .. [1] Luc Devroye, Non-Uniform Random Variate Generation,
       Springer-Verlag, New York, 1986.
    """
    if xmin < 1:
        raise ValueError("xmin < 1")
    if alpha <= 1:
        # NOTE(review): message says "a" while the parameter is `alpha`.
        raise ValueError("a <= 1.0")
    a1 = alpha - 1.0
    b = 2**a1
    # Devroye's rejection method: propose x from the continuous Pareto
    # tail via inversion, then accept with the Zipf/Pareto ratio test.
    while True:
        u = 1.0 - seed.random()  # u in (0,1]
        v = seed.random()  # v in [0,1)
        x = int(xmin * u ** -(1.0 / a1))
        t = (1.0 + (1.0 / x)) ** a1
        if v * x * (t - 1.0) / (b - 1.0) <= t / b:
            break
    return x
def cumulative_distribution(distribution):
    """Returns normalized cumulative distribution from discrete distribution.

    Parameters
    ----------
    distribution : sequence of numbers
        Histogram of values (weights); need not be normalized.

    Returns
    -------
    cdf : list of float
        Cumulative distribution with a leading 0.0, so it has one more
        entry than ``distribution``. ``cdf[i]`` is the total normalized
        mass of the first ``i`` bins (suitable for bisect-based sampling).
    """
    from itertools import accumulate

    psum = sum(distribution)
    # accumulate yields running sums at C speed; dividing by the total
    # normalizes, replacing the original manual index loop.
    return [0.0] + [partial / psum for partial in accumulate(distribution)]
@py_random_state(3)
def discrete_sequence(n, distribution=None, cdistribution=None, seed=None):
    """
    Return sample sequence of length n from a given discrete distribution
    or discrete cumulative distribution.

    One of the following must be specified.

    distribution = histogram of values, will be normalized

    cdistribution = normalized discrete cumulative distribution

    The ``seed`` argument is normalized by the ``py_random_state`` decorator.
    """
    import bisect
    if cdistribution is not None:
        cdf = cdistribution
    elif distribution is not None:
        cdf = cumulative_distribution(distribution)
    else:
        raise nx.NetworkXError(
            "discrete_sequence: distribution or cdistribution missing"
        )
    # get a uniform random number
    inputseq = [seed.random() for i in range(n)]
    # choose from CDF: bisect_left finds the first cdf entry >= s; the -1
    # compensates for the leading 0.0 in the cdf so indices start at 0.
    seq = [bisect.bisect_left(cdf, s) - 1 for s in inputseq]
    return seq
@py_random_state(2)
def random_weighted_sample(mapping, k, seed=None):
    """Returns k items without replacement from a weighted sample.

    The input is a dictionary of items with weights as values.
    The ``seed`` argument is normalized by the ``py_random_state`` decorator.
    """
    if k > len(mapping):
        raise ValueError("sample larger than population")
    sample = set()
    # Rejection-style sampling: draw repeatedly, relying on the set to
    # discard duplicates until k distinct items are collected.
    while len(sample) < k:
        sample.add(weighted_choice(mapping, seed))
    return list(sample)
@py_random_state(1)
def weighted_choice(mapping, seed=None):
    """Returns a single element from a weighted sample.

    The input is a dictionary of items with weights as values.
    The ``seed`` argument is normalized by the ``py_random_state`` decorator.
    """
    # use roulette method: scale a uniform draw by the total weight, then
    # walk the items subtracting weights until the draw is exhausted.
    rnd = seed.random() * sum(mapping.values())
    for k, w in mapping.items():
        rnd -= w
        if rnd < 0:
            return k
    # NOTE(review): falls through to an implicit None only if floating-point
    # rounding keeps rnd >= 0 after all weights are subtracted — confirm
    # whether callers can tolerate that.

View File

@@ -0,0 +1,158 @@
"""
Cuthill-McKee ordering of graph nodes to produce sparse matrices
"""
from collections import deque
from operator import itemgetter
import networkx as nx
from ..utils import arbitrary_element
__all__ = ["cuthill_mckee_ordering", "reverse_cuthill_mckee_ordering"]
def cuthill_mckee_ordering(G, heuristic=None):
    """Generate an ordering (permutation) of the graph nodes to make
    a sparse matrix.

    Uses the Cuthill-McKee heuristic (based on breadth-first search) [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph

    heuristic : function, optional
        Function to choose the starting node of the RCM algorithm.  If None,
        a node from a pseudo-peripheral pair is used.  A user-defined
        function can be supplied that takes a graph object and returns a
        single node.

    Returns
    -------
    nodes : generator
        Generator of nodes in Cuthill-McKee ordering.

    Examples
    --------
    >>> from networkx.utils import cuthill_mckee_ordering
    >>> G = nx.path_graph(4)
    >>> rcm = list(cuthill_mckee_ordering(G))
    >>> A = nx.adjacency_matrix(G, nodelist=rcm)

    Smallest degree node as heuristic function:

    >>> def smallest_degree(G):
    ...     return min(G, key=G.degree)
    >>> rcm = list(cuthill_mckee_ordering(G, heuristic=smallest_degree))

    See Also
    --------
    reverse_cuthill_mckee_ordering

    Notes
    -----
    The optimal solution to the bandwidth reduction problem is
    NP-complete [2]_.

    References
    ----------
    .. [1] E. Cuthill and J. McKee.
       Reducing the bandwidth of sparse symmetric matrices,
       In Proc. 24th Nat. Conf. ACM, pages 157-172, 1969.
       http://doi.acm.org/10.1145/800195.805928
    .. [2] Steven S. Skiena. 1997. The Algorithm Design Manual.
       Springer-Verlag New York, Inc., New York, NY, USA.
    """
    # Order each connected component independently and chain the results.
    for component in nx.connected_components(G):
        yield from connected_cuthill_mckee_ordering(G.subgraph(component), heuristic)
def reverse_cuthill_mckee_ordering(G, heuristic=None):
    """Generate an ordering (permutation) of the graph nodes to make
    a sparse matrix.

    Uses the reverse Cuthill-McKee heuristic (based on breadth-first search)
    [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph

    heuristic : function, optional
        Function to choose the starting node of the RCM algorithm.  If None,
        a node from a pseudo-peripheral pair is used.  A user-defined
        function can be supplied that takes a graph object and returns a
        single node.

    Returns
    -------
    nodes : generator
        Generator of nodes in reverse Cuthill-McKee ordering.

    Examples
    --------
    >>> from networkx.utils import reverse_cuthill_mckee_ordering
    >>> G = nx.path_graph(4)
    >>> rcm = list(reverse_cuthill_mckee_ordering(G))
    >>> A = nx.adjacency_matrix(G, nodelist=rcm)

    Smallest degree node as heuristic function:

    >>> def smallest_degree(G):
    ...     return min(G, key=G.degree)
    >>> rcm = list(reverse_cuthill_mckee_ordering(G, heuristic=smallest_degree))

    See Also
    --------
    cuthill_mckee_ordering

    Notes
    -----
    The optimal solution to the bandwidth reduction problem is
    NP-complete [2]_.

    References
    ----------
    .. [1] E. Cuthill and J. McKee.
       Reducing the bandwidth of sparse symmetric matrices,
       In Proc. 24th Nat. Conf. ACM, pages 157-172, 1969.
       http://doi.acm.org/10.1145/800195.805928
    .. [2] Steven S. Skiena. 1997. The Algorithm Design Manual.
       Springer-Verlag New York, Inc., New York, NY, USA.
    """
    # Materialize the forward ordering, then iterate it backwards.
    forward = list(cuthill_mckee_ordering(G, heuristic=heuristic))
    return reversed(forward)
def connected_cuthill_mckee_ordering(G, heuristic=None):
    # Cuthill-McKee ordering for a single connected graph: BFS from a good
    # start node, visiting unvisited neighbors in increasing degree order.
    start = pseudo_peripheral_node(G) if heuristic is None else heuristic(G)
    visited = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        yield node
        frontier = set(G[node]) - visited
        neighbors = [n for n, _ in sorted(G.degree(frontier), key=itemgetter(1))]
        visited.update(neighbors)
        queue.extend(neighbors)
def pseudo_peripheral_node(G):
    # Helper for cuthill_mckee_ordering: locate one node of a
    # "pseudo-peripheral pair" to serve as a good BFS starting node.
    node = arbitrary_element(G)
    best_ecc = 0
    while True:
        lengths = dict(nx.shortest_path_length(G, node))
        ecc = max(lengths.values())
        # Stop once the eccentricity no longer improves.
        if ecc <= best_ecc:
            return node
        best_ecc = ecc
        # Among the farthest nodes, continue from one of minimum degree.
        periphery = (n for n, dist in lengths.items() if dist == ecc)
        node, _ = min(G.degree(periphery), key=itemgetter(1))

View File

@@ -0,0 +1,11 @@
import pytest
def test_utils_namespace():
    """Ensure objects are not unintentionally exposed in utils namespace."""
    # Each name below is used internally by networkx.utils modules but must
    # not leak out as part of the public ``networkx.utils`` namespace.
    with pytest.raises(ImportError):
        from networkx.utils import nx
    with pytest.raises(ImportError):
        from networkx.utils import sys
    with pytest.raises(ImportError):
        from networkx.utils import defaultdict, deque

View File

@@ -0,0 +1,491 @@
import os
import pathlib
import random
import tempfile
import pytest
import networkx as nx
from networkx.utils.decorators import (
argmap,
not_implemented_for,
np_random_state,
open_file,
py_random_state,
)
from networkx.utils.misc import PythonRandomInterface
def test_not_implemented_decorator():
    """not_implemented_for: allowed graph kinds pass, excluded kinds raise."""

    @not_implemented_for("directed")
    def test_d(G):
        pass

    test_d(nx.Graph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_d(nx.DiGraph())

    @not_implemented_for("undirected")
    def test_u(G):
        pass

    test_u(nx.DiGraph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_u(nx.Graph())

    @not_implemented_for("multigraph")
    def test_m(G):
        pass

    test_m(nx.Graph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_m(nx.MultiGraph())

    @not_implemented_for("graph")
    def test_g(G):
        pass

    test_g(nx.MultiGraph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_g(nx.Graph())

    # not MultiDiGraph (multiple arguments => AND)
    @not_implemented_for("directed", "multigraph")
    def test_not_md(G):
        pass

    test_not_md(nx.Graph())
    test_not_md(nx.DiGraph())
    test_not_md(nx.MultiGraph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_not_md(nx.MultiDiGraph())

    # Graph only (multiple decorators => OR)
    @not_implemented_for("directed")
    @not_implemented_for("multigraph")
    def test_graph_only(G):
        pass

    test_graph_only(nx.Graph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_graph_only(nx.DiGraph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_graph_only(nx.MultiGraph())
    with pytest.raises(nx.NetworkXNotImplemented):
        test_graph_only(nx.MultiDiGraph())

    # Contradictory property combinations are rejected eagerly.
    with pytest.raises(ValueError):
        not_implemented_for("directed", "undirected")

    with pytest.raises(ValueError):
        not_implemented_for("multigraph", "graph")
def test_not_implemented_decorator_key():
    """An unknown graph-property name surfaces as KeyError."""

    def inner(G):
        pass

    with pytest.raises(KeyError):
        # The decorator validates lazily, so both the wrapping and the
        # first call stay inside the raises block.
        wrapped = not_implemented_for("foo")(inner)
        wrapped(nx.Graph())
def test_not_implemented_decorator_raise():
    """Calling with the excluded graph kind raises NetworkXNotImplemented."""

    def inner(G):
        pass

    with pytest.raises(nx.NetworkXNotImplemented):
        wrapped = not_implemented_for("graph")(inner)
        wrapped(nx.Graph())
class TestOpenFileDecorator:
    """Exercise ``open_file`` with str paths, pathlib paths and file objects."""

    def setup_method(self):
        # Fresh scratch file per test; removed again in teardown_method.
        self.text = ["Blah... ", "BLAH ", "BLAH!!!!"]
        self.fobj = tempfile.NamedTemporaryFile("wb+", delete=False)
        self.name = self.fobj.name

    def teardown_method(self):
        self.fobj.close()
        os.unlink(self.name)

    def write(self, path):
        # Write the fixture text to an already-open binary file object.
        for text in self.text:
            path.write(text.encode("ascii"))

    @open_file(1, "r")
    def read(self, path):
        return path.readlines()[0]

    @staticmethod
    @open_file(0, "wb")
    def writer_arg0(path):
        path.write(b"demo")

    @open_file(1, "wb+")
    def writer_arg1(self, path):
        self.write(path)

    @open_file(2, "wb")
    def writer_arg2default(self, x, path=None):
        if path is None:
            with tempfile.NamedTemporaryFile("wb+") as fh:
                self.write(fh)
        else:
            self.write(path)

    @open_file(4, "wb")
    def writer_arg4default(self, x, y, other="hello", path=None, **kwargs):
        if path is None:
            with tempfile.NamedTemporaryFile("wb+") as fh:
                self.write(fh)
        else:
            self.write(path)

    @open_file("path", "wb")
    def writer_kwarg(self, **kwargs):
        path = kwargs.get("path", None)
        if path is None:
            with tempfile.NamedTemporaryFile("wb+") as fh:
                self.write(fh)
        else:
            self.write(path)

    def test_writer_arg0_str(self):
        self.writer_arg0(self.name)

    def test_writer_arg0_fobj(self):
        self.writer_arg0(self.fobj)

    def test_writer_arg0_pathlib(self):
        self.writer_arg0(pathlib.Path(self.name))

    def test_writer_arg1_str(self):
        self.writer_arg1(self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg1_fobj(self):
        # A caller-supplied file object must not be closed by the decorator.
        self.writer_arg1(self.fobj)
        assert not self.fobj.closed
        self.fobj.close()
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg2default_str(self):
        self.writer_arg2default(0, path=None)
        self.writer_arg2default(0, path=self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg2default_fobj(self):
        self.writer_arg2default(0, path=self.fobj)
        assert not self.fobj.closed
        self.fobj.close()
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg2default_fobj_path_none(self):
        self.writer_arg2default(0, path=None)

    def test_writer_arg4default_fobj(self):
        self.writer_arg4default(0, 1, dog="dog", other="other")
        self.writer_arg4default(0, 1, dog="dog", other="other", path=self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_kwarg_str(self):
        self.writer_kwarg(path=self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_kwarg_fobj(self):
        self.writer_kwarg(path=self.fobj)
        self.fobj.close()
        assert self.read(self.name) == "".join(self.text)

    def test_writer_kwarg_path_none(self):
        self.writer_kwarg(path=None)
class TestRandomState:
    """Behavior of the ``np_random_state`` / ``py_random_state`` decorators."""

    @classmethod
    def setup_class(cls):
        global np
        np = pytest.importorskip("numpy")

    @np_random_state(1)
    def instantiate_np_random_state(self, random_state):
        assert isinstance(random_state, np.random.RandomState)
        return random_state.random_sample()

    @py_random_state(1)
    def instantiate_py_random_state(self, random_state):
        assert isinstance(random_state, (random.Random, PythonRandomInterface))
        return random_state.random()

    def test_random_state_None(self):
        np.random.seed(42)
        rv = np.random.random_sample()
        np.random.seed(42)
        assert rv == self.instantiate_np_random_state(None)

        random.seed(42)
        rv = random.random()
        random.seed(42)
        assert rv == self.instantiate_py_random_state(None)

    def test_random_state_np_random(self):
        np.random.seed(42)
        rv = np.random.random_sample()
        np.random.seed(42)
        assert rv == self.instantiate_np_random_state(np.random)
        np.random.seed(42)
        assert rv == self.instantiate_py_random_state(np.random)

    def test_random_state_int(self):
        np.random.seed(42)
        np_rv = np.random.random_sample()
        random.seed(42)
        py_rv = random.random()

        np.random.seed(42)
        seed = 1
        rval = self.instantiate_np_random_state(seed)
        rval_expected = np.random.RandomState(seed).rand()
        # NOTE(review): ``assert rval, rval_expected`` only checks that rval
        # is truthy — rval_expected is the assert *message*, not a compared
        # value.  Probably meant ``assert rval == rval_expected``; confirm
        # the decorator semantics before tightening.
        assert rval, rval_expected
        # test that global seed wasn't changed in function
        assert np_rv == np.random.random_sample()

        random.seed(42)
        rval = self.instantiate_py_random_state(seed)
        rval_expected = random.Random(seed).random()
        # NOTE(review): same always-true assert pattern as above.
        assert rval, rval_expected
        # test that global seed wasn't changed in function
        assert py_rv == random.random()

    def test_random_state_np_random_RandomState(self):
        np.random.seed(42)
        np_rv = np.random.random_sample()

        np.random.seed(42)
        seed = 1
        rng = np.random.RandomState(seed)
        rval = self.instantiate_np_random_state(seed)
        rval_expected = np.random.RandomState(seed).rand()
        # NOTE(review): truthiness-only assert; see test_random_state_int.
        assert rval, rval_expected

        rval = self.instantiate_py_random_state(seed)
        rval_expected = np.random.RandomState(seed).rand()
        # NOTE(review): truthiness-only assert; see test_random_state_int.
        assert rval, rval_expected
        # test that global seed wasn't changed in function
        assert np_rv == np.random.random_sample()

    def test_random_state_py_random(self):
        seed = 1
        rng = random.Random(seed)
        rv = self.instantiate_py_random_state(rng)
        # NOTE(review): truthiness-only assert here as well.
        assert rv, random.Random(seed).random()

        pytest.raises(ValueError, self.instantiate_np_random_state, rng)
def test_random_state_string_arg_index():
    """A string argument index is rejected by np_random_state."""

    def func(rs):
        pass

    with pytest.raises(nx.NetworkXError):
        # Validation may happen lazily, so keep both the wrapping and the
        # first call inside the raises block.
        make_random_state = np_random_state("a")(func)
        make_random_state(1)
def test_py_random_state_string_arg_index():
    """A string argument index is rejected by py_random_state."""

    def func(rs):
        pass

    with pytest.raises(nx.NetworkXError):
        make_random_state = py_random_state("a")(func)
        make_random_state(1)
def test_random_state_invalid_arg_index():
    """An out-of-range argument index is rejected by np_random_state."""

    def func(rs):
        pass

    with pytest.raises(nx.NetworkXError):
        make_random_state = np_random_state(2)(func)
        make_random_state(1)
def test_py_random_state_invalid_arg_index():
    """An out-of-range argument index is rejected by py_random_state."""

    def func(rs):
        pass

    with pytest.raises(nx.NetworkXError):
        make_random_state = py_random_state(2)(func)
        make_random_state(1)
class TestArgmap:
    """Tests for the low-level ``argmap`` decorator machinery."""

    class ArgmapError(RuntimeError):
        pass

    def test_trivial_function(self):
        def do_not_call(x):
            # NOTE(review): the bare name ``ArgmapError`` is not resolvable
            # here (the class lives on TestArgmap); harmless because this
            # mapper must never actually be invoked.
            raise ArgmapError("do not call this function")

        @argmap(do_not_call)
        def trivial_argmap():
            return 1

        assert trivial_argmap() == 1

    def test_trivial_iterator(self):
        def do_not_call(x):
            raise ArgmapError("do not call this function")

        @argmap(do_not_call)
        def trivial_argmap():
            yield from (1, 2, 3)

        assert tuple(trivial_argmap()) == (1, 2, 3)

    def test_contextmanager(self):
        container = []

        def contextmanager(x):
            nonlocal container
            return x, lambda: container.append(x)

        @argmap(contextmanager, 0, 1, 2, try_finally=True)
        def foo(x, y, z):
            return x, y, z

        x, y, z = foo("a", "b", "c")

        # context exits are called in reverse
        assert container == ["c", "b", "a"]

    def test_tryfinally_generator(self):
        container = []

        def singleton(x):
            return (x,)

        # try_finally cannot wrap a generator function.
        with pytest.raises(nx.NetworkXError):

            @argmap(singleton, 0, 1, 2, try_finally=True)
            def foo(x, y, z):
                yield from (x, y, z)

        @argmap(singleton, 0, 1, 2)
        def foo(x, y, z):
            return x + y + z

        q = foo("a", "b", "c")

        assert q == ("a", "b", "c")

    def test_actual_vararg(self):
        # Positional index 4 lands inside *args.
        @argmap(lambda x: -x, 4)
        def foo(x, y, *args):
            return (x, y) + tuple(args)

        assert foo(1, 2, 3, 4, 5, 6) == (1, 2, 3, 4, -5, 6)

    def test_signature_destroying_intermediate_decorator(self):
        def add_one_to_first_bad_decorator(f):
            """Bad because it doesn't wrap the f signature (clobbers it)"""

            def decorated(a, *args, **kwargs):
                return f(a + 1, *args, **kwargs)

            return decorated

        add_two_to_second = argmap(lambda b: b + 2, 1)

        @add_two_to_second
        @add_one_to_first_bad_decorator
        def add_one_and_two(a, b):
            return a, b

        assert add_one_and_two(5, 5) == (6, 7)

    def test_actual_kwarg(self):
        @argmap(lambda x: -x, "arg")
        def foo(*, arg):
            return arg

        assert foo(arg=3) == -3

    def test_nested_tuple(self):
        def xform(x, y):
            u, v = y
            return x + u + v, (x + u, x + v)

        # we're testing args and kwargs here, too
        @argmap(xform, (0, ("t", 2)))
        def foo(a, *args, **kwargs):
            return a, args, kwargs

        a, args, kwargs = foo(1, 2, 3, t=4)

        assert a == 1 + 4 + 3
        assert args == (2, 1 + 3)
        assert kwargs == {"t": 1 + 4}

    def test_flatten(self):
        assert tuple(argmap._flatten([[[[[], []], [], []], [], [], []]], set())) == ()

        rlist = ["a", ["b", "c"], [["d"], "e"], "f"]
        assert "".join(argmap._flatten(rlist, set())) == "abcdef"

    def test_indent(self):
        # _indent adds one space of indentation per open try/finally level.
        code = "\n".join(
            argmap._indent(
                *[
                    "try:",
                    "try:",
                    "pass#",
                    "finally:",
                    "pass#",
                    "#",
                    "finally:",
                    "pass#",
                ]
            )
        )
        assert (
            code
            == """try:
 try:
  pass#
 finally:
  pass#
 #
finally:
 pass#"""
        )

    def test_immediate_raise(self):
        @not_implemented_for("directed")
        def yield_nodes(G):
            yield from G

        G = nx.Graph([(1, 2)])
        D = nx.DiGraph()

        # test first call (argmap is compiled and executed)
        with pytest.raises(nx.NetworkXNotImplemented):
            node_iter = yield_nodes(D)

        # test second call (argmap is only executed)
        with pytest.raises(nx.NetworkXNotImplemented):
            node_iter = yield_nodes(D)

        # ensure that generators still make generators
        node_iter = yield_nodes(G)
        next(node_iter)
        next(node_iter)
        with pytest.raises(StopIteration):
            next(node_iter)

View File

@@ -0,0 +1,131 @@
import pytest
import networkx as nx
from networkx.utils import BinaryHeap, PairingHeap
class X:
    """Unhashable-by-comparison sentinel key for the heap tests.

    Every comparison special method raises, so any heap implementation
    that tries to order or equate *keys* (instead of values) fails loudly.
    """

    def __eq__(self, other):
        # NOTE(review): ``raise <bool>`` raises ``TypeError: exceptions must
        # derive from BaseException`` — so any use of ``==`` errors out.
        raise self is other

    def __ne__(self, other):
        # NOTE(review): same raise-a-bool pattern as __eq__.
        raise self is not other

    def __lt__(self, other):
        raise TypeError("cannot compare")

    def __le__(self, other):
        raise TypeError("cannot compare")

    def __ge__(self, other):
        raise TypeError("cannot compare")

    def __gt__(self, other):
        raise TypeError("cannot compare")

    def __hash__(self):
        # Identity-based hash keeps X usable as a dict/heap key.
        return hash(id(self))
# Single shared incomparable key used throughout the scripted scenario below.
x = X()

# Scripted heap scenario: each tuple is (method_name, *args, expected),
# where an expected value of nx.NetworkXError means the call must raise.
data = [  # min should not invent an element.
    ("min", nx.NetworkXError),
    # Popping an empty heap should fail.
    ("pop", nx.NetworkXError),
    # Getting nonexisting elements should return None.
    ("get", 0, None),
    ("get", x, None),
    ("get", None, None),
    # Inserting a new key should succeed.
    ("insert", x, 1, True),
    ("get", x, 1),
    ("min", (x, 1)),
    # min should not pop the top element.
    ("min", (x, 1)),
    # Inserting a new key of different type should succeed.
    ("insert", 1, -2.0, True),
    # int and float values should interop.
    ("min", (1, -2.0)),
    # pop removes minimum-valued element.
    ("insert", 3, -(10**100), True),
    ("insert", 4, 5, True),
    ("pop", (3, -(10**100))),
    ("pop", (1, -2.0)),
    # Decrease-insert should succeed.
    ("insert", 4, -50, True),
    ("insert", 4, -60, False, True),
    # Decrease-insert should not create duplicate keys.
    ("pop", (4, -60)),
    ("pop", (x, 1)),
    # Popping all elements should empty the heap.
    ("min", nx.NetworkXError),
    ("pop", nx.NetworkXError),
    # Non-value-changing insert should fail.
    ("insert", x, 0, True),
    ("insert", x, 0, False, False),
    ("min", (x, 0)),
    ("insert", x, 0, True, False),
    ("min", (x, 0)),
    # Failed insert should not create duplicate keys.
    ("pop", (x, 0)),
    ("pop", nx.NetworkXError),
    # Increase-insert should succeed when allowed.
    ("insert", None, 0, True),
    ("insert", 2, -1, True),
    ("min", (2, -1)),
    ("insert", 2, 1, True, False),
    ("min", (None, 0)),
    # Increase-insert should fail when disallowed.
    ("insert", None, 2, False, False),
    ("min", (None, 0)),
    # Failed increase-insert should not create duplicate keys.
    ("pop", (None, 0)),
    ("pop", (2, 1)),
    ("min", nx.NetworkXError),
    ("pop", nx.NetworkXError),
]
def _test_heap_class(cls, *args, **kwargs):
    """Run the scripted ``data`` scenario plus a coverage sweep on ``cls``."""
    heap = cls(*args, **kwargs)
    # Basic behavioral test
    for op in data:
        if op[-1] is not nx.NetworkXError:
            assert op[-1] == getattr(heap, op[0])(*op[1:-1])
        else:
            pytest.raises(op[-1], getattr(heap, op[0]), *op[1:-1])
    # Coverage test.  The exact order below is significant: it drives the
    # heap through decrease-, increase- and no-op insert code paths.
    for i in range(99, -1, -1):
        assert heap.insert(i, i)
    for i in range(50):
        assert heap.pop() == (i, i)
    for i in range(100):
        assert heap.insert(i, i) == (i < 50)
    for i in range(100):
        assert not heap.insert(i, i + 1)
    for i in range(50):
        assert heap.pop() == (i, i)
    for i in range(100):
        assert heap.insert(i, i + 1) == (i < 50)
    for i in range(49):
        assert heap.pop() == (i, i + 1)
    assert sorted([heap.pop(), heap.pop()]) == [(49, 50), (50, 50)]
    for i in range(51, 100):
        assert not heap.insert(i, i + 1, True)
    for i in range(51, 70):
        assert heap.pop() == (i, i + 1)
    for i in range(100):
        assert heap.insert(i, i)
    for i in range(100):
        assert heap.pop() == (i, i)
    pytest.raises(nx.NetworkXError, heap.pop)
def test_PairingHeap():
    """Run the shared heap test battery against PairingHeap."""
    _test_heap_class(PairingHeap)
def test_BinaryHeap():
    """Run the shared heap test battery against BinaryHeap."""
    _test_heap_class(BinaryHeap)

View File

@@ -0,0 +1,268 @@
import pytest
from networkx.utils.mapped_queue import MappedQueue, _HeapElement
def test_HeapElement_gtlt():
    """Elements compare primarily by priority, against elements or numbers."""
    heavier = _HeapElement(1.1, "a")
    lighter = _HeapElement(1, "b")
    assert lighter < heavier
    assert heavier > lighter
    assert lighter < 1.1
    assert 1 < heavier
def test_HeapElement_gtlt_tied_priority():
    """With equal priorities, the tie is broken by the element payload."""
    first = _HeapElement(1, "a")
    second = _HeapElement(1, "b")
    assert second > first
    assert first < second
def test_HeapElement_eq():
    """Equality ignores priority and compares only the payload."""
    lhs = _HeapElement(1.1, "a")
    rhs = _HeapElement(1, "a")
    assert rhs == lhs
    assert lhs == rhs
    assert rhs == "a"
def test_HeapElement_iter():
    """Iteration yields the priority followed by the (flattened) payload."""
    single = _HeapElement(1, "a")
    packed = _HeapElement(1.1, (3, 2, 1))
    assert list(single) == [1, "a"]
    assert list(packed) == [1.1, 3, 2, 1]
def test_HeapElement_getitem():
    """Indexing treats the element as the tuple (priority, *payload)."""
    single = _HeapElement(1, "a")
    packed = _HeapElement(1.1, (3, 2, 1))
    assert single[1] == "a"
    assert single[0] == 1
    assert packed[0] == 1.1
    assert packed[2] == 2
    assert packed[3] == 1
    # Indexing past the end of either element raises.
    pytest.raises(IndexError, packed.__getitem__, 4)
    pytest.raises(IndexError, single.__getitem__, 2)
class TestMappedQueue:
    """Heap-invariant and position-map tests for MappedQueue."""

    def setup_method(self):
        pass

    def _check_map(self, q):
        # The position map must mirror the heap list exactly.
        assert q.position == {elt: pos for pos, elt in enumerate(q.heap)}

    def _make_mapped_queue(self, h):
        # Build a queue around a raw heap list without running _heapify.
        q = MappedQueue()
        q.heap = h
        q.position = {elt: pos for pos, elt in enumerate(h)}
        return q

    def test_heapify(self):
        h = [5, 4, 3, 2, 1, 0]
        q = self._make_mapped_queue(h)
        q._heapify()
        self._check_map(q)

    def test_init(self):
        h = [5, 4, 3, 2, 1, 0]
        q = MappedQueue(h)
        self._check_map(q)

    def test_incomparable(self):
        # Mixed int/str elements cannot be ordered.
        h = [5, 4, "a", 2, 1, 0]
        pytest.raises(TypeError, MappedQueue, h)

    def test_len(self):
        h = [5, 4, 3, 2, 1, 0]
        q = MappedQueue(h)
        self._check_map(q)
        assert len(q) == 6

    def test_siftup_leaf(self):
        # Sifting a leaf is a no-op.
        h = [2]
        h_sifted = [2]
        q = self._make_mapped_queue(h)
        q._siftup(0)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftup_one_child(self):
        h = [2, 0]
        h_sifted = [0, 2]
        q = self._make_mapped_queue(h)
        q._siftup(0)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftup_left_child(self):
        h = [2, 0, 1]
        h_sifted = [0, 2, 1]
        q = self._make_mapped_queue(h)
        q._siftup(0)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftup_right_child(self):
        h = [2, 1, 0]
        h_sifted = [0, 1, 2]
        q = self._make_mapped_queue(h)
        q._siftup(0)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftup_multiple(self):
        h = [0, 1, 2, 4, 3, 5, 6]
        h_sifted = [0, 1, 2, 4, 3, 5, 6]
        q = self._make_mapped_queue(h)
        q._siftup(0)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftdown_leaf(self):
        h = [2]
        h_sifted = [2]
        q = self._make_mapped_queue(h)
        q._siftdown(0, 0)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftdown_single(self):
        h = [1, 0]
        h_sifted = [0, 1]
        q = self._make_mapped_queue(h)
        q._siftdown(0, len(h) - 1)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_siftdown_multiple(self):
        h = [1, 2, 3, 4, 5, 6, 7, 0]
        h_sifted = [0, 1, 3, 2, 5, 6, 7, 4]
        q = self._make_mapped_queue(h)
        q._siftdown(0, len(h) - 1)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_push(self):
        to_push = [6, 1, 4, 3, 2, 5, 0]
        h_sifted = [0, 2, 1, 6, 3, 5, 4]
        q = MappedQueue()
        for elt in to_push:
            q.push(elt)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_push_duplicate(self):
        # Pushing an element already present must be refused.
        to_push = [2, 1, 0]
        h_sifted = [0, 2, 1]
        q = MappedQueue()
        for elt in to_push:
            inserted = q.push(elt)
            assert inserted
        assert q.heap == h_sifted
        self._check_map(q)
        inserted = q.push(1)
        assert not inserted

    def test_pop(self):
        h = [3, 4, 6, 0, 1, 2, 5]
        h_sorted = sorted(h)
        q = self._make_mapped_queue(h)
        q._heapify()
        popped = [q.pop() for _ in range(len(h))]
        assert popped == h_sorted
        self._check_map(q)

    def test_remove_leaf(self):
        h = [0, 2, 1, 6, 3, 5, 4]
        h_removed = [0, 2, 1, 6, 4, 5]
        q = self._make_mapped_queue(h)
        removed = q.remove(3)
        assert q.heap == h_removed

    def test_remove_root(self):
        h = [0, 2, 1, 6, 3, 5, 4]
        h_removed = [1, 2, 4, 6, 3, 5]
        q = self._make_mapped_queue(h)
        removed = q.remove(0)
        assert q.heap == h_removed

    def test_update_leaf(self):
        h = [0, 20, 10, 60, 30, 50, 40]
        h_updated = [0, 15, 10, 60, 20, 50, 40]
        q = self._make_mapped_queue(h)
        removed = q.update(30, 15)
        assert q.heap == h_updated

    def test_update_root(self):
        h = [0, 20, 10, 60, 30, 50, 40]
        h_updated = [10, 20, 35, 60, 30, 50, 40]
        q = self._make_mapped_queue(h)
        removed = q.update(0, 35)
        assert q.heap == h_updated
class TestMappedDict(TestMappedQueue):
    """Re-run the queue tests using dict (element -> priority) construction."""

    def _make_mapped_queue(self, h):
        priority_dict = {elt: elt for elt in h}
        return MappedQueue(priority_dict)

    def test_init(self):
        d = {5: 0, 4: 1, "a": 2, 2: 3, 1: 4}
        q = MappedQueue(d)
        assert q.position == d

    def test_ties(self):
        d = {5: 0, 4: 1, 3: 2, 2: 3, 1: 4}
        q = MappedQueue(d)
        assert q.position == {elt: pos for pos, elt in enumerate(q.heap)}

    def test_pop(self):
        d = {5: 0, 4: 1, 3: 2, 2: 3, 1: 4}
        q = MappedQueue(d)
        assert q.pop() == _HeapElement(0, 5)
        assert q.position == {elt: pos for pos, elt in enumerate(q.heap)}

    def test_empty_pop(self):
        q = MappedQueue()
        pytest.raises(IndexError, q.pop)

    def test_incomparable_ties(self):
        # Tied priorities force payload comparison, which fails for int/str.
        d = {5: 0, 4: 0, "a": 0, 2: 0, 1: 0}
        pytest.raises(TypeError, MappedQueue, d)

    def test_push(self):
        to_push = [6, 1, 4, 3, 2, 5, 0]
        h_sifted = [0, 2, 1, 6, 3, 5, 4]
        q = MappedQueue()
        for elt in to_push:
            q.push(elt, priority=elt)
        assert q.heap == h_sifted
        self._check_map(q)

    def test_push_duplicate(self):
        to_push = [2, 1, 0]
        h_sifted = [0, 2, 1]
        q = MappedQueue()
        for elt in to_push:
            inserted = q.push(elt, priority=elt)
            assert inserted
        assert q.heap == h_sifted
        self._check_map(q)
        inserted = q.push(1, priority=1)
        assert not inserted

    def test_update_leaf(self):
        h = [0, 20, 10, 60, 30, 50, 40]
        h_updated = [0, 15, 10, 60, 20, 50, 40]
        q = self._make_mapped_queue(h)
        removed = q.update(30, 15, priority=15)
        assert q.heap == h_updated

    def test_update_root(self):
        h = [0, 20, 10, 60, 30, 50, 40]
        h_updated = [10, 20, 35, 60, 30, 50, 40]
        q = self._make_mapped_queue(h)
        removed = q.update(0, 35, priority=35)
        assert q.heap == h_updated

View File

@@ -0,0 +1,255 @@
import random
from copy import copy
import pytest
import networkx as nx
from networkx.utils import (
PythonRandomInterface,
arbitrary_element,
create_py_random_state,
create_random_state,
dict_to_numpy_array,
discrete_sequence,
flatten,
groups,
make_list_of_ints,
pairwise,
powerlaw_sequence,
)
from networkx.utils.misc import _dict_to_numpy_array1, _dict_to_numpy_array2
# Shared fixtures for test_flatten: three differently-shaped nested
# containers (tuple, set, mixed list/dict), each flattening to exactly
# 20 leaf items.
nested_depth = (
    1,
    2,
    (3, 4, ((5, 6, (7,), (8, (9, 10), 11), (12, 13, (14, 15)), 16), 17), 18, 19),
    20,
)

nested_set = {
    (1, 2, 3, 4),
    (5, 6, 7, 8, 9),
    (10, 11, (12, 13, 14), (15, 16, 17, 18)),
    19,
    20,
}

nested_mixed = [
    1,
    (2, 3, {4, (5, 6), 7}, [8, 9]),
    {10: "foo", 11: "bar", (12, 13): "baz"},
    {(14, 15): "qwe", 16: "asd"},
    (17, (18, "19"), 20),
]
@pytest.mark.parametrize("result", [None, [], ["existing"], ["existing1", "existing2"]])
@pytest.mark.parametrize("nested", [nested_depth, nested_mixed, nested_set])
def test_flatten(nested, result):
    """flatten yields a 20-leaf tuple, appended to ``result`` when given."""
    if result is None:
        val = flatten(nested, result)
        assert len(val) == 20
    else:
        _result = copy(result)  # because pytest passes parameters as is
        nexisting = len(_result)
        val = flatten(nested, _result)
        assert len(val) == len(_result) == 20 + nexisting
    assert issubclass(type(val), tuple)
def test_make_list_of_ints():
    """Int-valued floats are coerced in place; anything else is rejected."""
    mylist = [1, 2, 3.0, 42, -2]
    # The same list object is returned, mutated in place ...
    assert make_list_of_ints(mylist) is mylist
    assert make_list_of_ints(mylist) == mylist
    # ... with int-valued float entries replaced by real ints.
    assert type(make_list_of_ints(mylist)[2]) is int
    pytest.raises(nx.NetworkXError, make_list_of_ints, [1, 2, 3, "kermit"])
    pytest.raises(nx.NetworkXError, make_list_of_ints, [1, 2, 3.1])
def test_random_number_distribution():
    """Smoke test: these sequence generators run without raising."""
    powerlaw_sequence(20, exponent=2.5)
    discrete_sequence(20, distribution=[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3])
class TestNumpyArray:
    """dict/array conversion helpers and int coercion with numpy inputs."""

    @classmethod
    def setup_class(cls):
        global np
        np = pytest.importorskip("numpy")

    def test_numpy_to_list_of_ints(self):
        a = np.array([1, 2, 3], dtype=np.int64)
        b = np.array([1.0, 2, 3])
        c = np.array([1.1, 2, 3])
        assert type(make_list_of_ints(a)) == list
        assert make_list_of_ints(b) == list(b)
        B = make_list_of_ints(b)
        assert type(B[0]) == int
        # Non-integral floats must be rejected.
        pytest.raises(nx.NetworkXError, make_list_of_ints, c)

    def test__dict_to_numpy_array1(self):
        # The mapping fixes each key's index in the output array.
        d = {"a": 1, "b": 2}
        a = _dict_to_numpy_array1(d, mapping={"a": 0, "b": 1})
        np.testing.assert_allclose(a, np.array([1, 2]))
        a = _dict_to_numpy_array1(d, mapping={"b": 0, "a": 1})
        np.testing.assert_allclose(a, np.array([2, 1]))

        a = _dict_to_numpy_array1(d)
        np.testing.assert_allclose(a.sum(), 3)

    def test__dict_to_numpy_array2(self):
        d = {"a": {"a": 1, "b": 2}, "b": {"a": 10, "b": 20}}

        mapping = {"a": 1, "b": 0}
        a = _dict_to_numpy_array2(d, mapping=mapping)
        np.testing.assert_allclose(a, np.array([[20, 10], [2, 1]]))

        a = _dict_to_numpy_array2(d)
        np.testing.assert_allclose(a.sum(), 33)

    def test_dict_to_numpy_array_a(self):
        # A dict-of-dicts dispatches to the 2-D converter.
        d = {"a": {"a": 1, "b": 2}, "b": {"a": 10, "b": 20}}

        mapping = {"a": 0, "b": 1}
        a = dict_to_numpy_array(d, mapping=mapping)
        np.testing.assert_allclose(a, np.array([[1, 2], [10, 20]]))

        mapping = {"a": 1, "b": 0}
        a = dict_to_numpy_array(d, mapping=mapping)
        np.testing.assert_allclose(a, np.array([[20, 10], [2, 1]]))

        a = _dict_to_numpy_array2(d)
        np.testing.assert_allclose(a.sum(), 33)

    def test_dict_to_numpy_array_b(self):
        # A flat dict dispatches to the 1-D converter.
        d = {"a": 1, "b": 2}

        mapping = {"a": 0, "b": 1}
        a = dict_to_numpy_array(d, mapping=mapping)
        np.testing.assert_allclose(a, np.array([1, 2]))

        a = _dict_to_numpy_array1(d)
        np.testing.assert_allclose(a.sum(), 3)
def test_pairwise():
    """pairwise yields adjacent pairs; cyclic=True appends (last, first)."""
    nodes = range(4)
    expected = [(0, 1), (1, 2), (2, 3)]
    assert list(pairwise(nodes)) == expected
    assert list(pairwise(iter(nodes))) == expected
    assert list(pairwise(nodes, cyclic=True)) == expected + [(3, 0)]
    # Empty input yields nothing, cyclic or not.
    assert list(pairwise(iter(()))) == []
    assert list(pairwise(iter(()), cyclic=True)) == []
def test_groups():
    """groups() inverts a many-to-one mapping into value -> set of keys."""
    many_to_one = dict(zip("abcde", [0, 0, 1, 1, 2]))
    assert groups(many_to_one) == {0: {"a", "b"}, 1: {"c", "d"}, 2: {"e"}}
    # An empty mapping inverts to an empty mapping.
    assert groups({}) == {}
def test_create_random_state():
    """create_random_state accepts ints, None, modules, RandomState, Generator."""
    np = pytest.importorskip("numpy")
    rs = np.random.RandomState

    # Each of these inputs yields a RandomState instance.
    for arg in (1, None, np.random, rs(1)):
        assert isinstance(create_random_state(arg), rs)
    # Support for numpy.random.Generator
    rng = np.random.default_rng()
    assert isinstance(create_random_state(rng), np.random.Generator)
    pytest.raises(ValueError, create_random_state, "a")

    # Same integer seed reproduces the same stream.
    assert np.all(rs(1).rand(10) == create_random_state(1).rand(10))
def test_create_py_random_state():
    """create_py_random_state wraps numpy RNGs and passes through Python ones."""
    pyrs = random.Random
    for arg in (1, None, pyrs(1)):
        assert isinstance(create_py_random_state(arg), pyrs)
    pytest.raises(ValueError, create_py_random_state, "a")

    np = pytest.importorskip("numpy")
    rs = np.random.RandomState
    rng = np.random.default_rng(1000)
    rng_explicit = np.random.Generator(np.random.SFC64())
    nprs = PythonRandomInterface
    # numpy-based sources are wrapped in a PythonRandomInterface adapter.
    for arg in (np.random, rs(1), rng, rng_explicit):
        assert isinstance(create_py_random_state(arg), nprs)
    # test default rng input
    assert isinstance(PythonRandomInterface(), nprs)
def test_PythonRandomInterface_RandomState():
    """PythonRandomInterface over RandomState mirrors random.Random's API.

    NOTE(review): each assert consumes one draw from *both* generators in
    lockstep, so statement order here is load-bearing — do not reorder.
    """
    np = pytest.importorskip("numpy")
    rs = np.random.RandomState
    rng = PythonRandomInterface(rs(42))
    rs42 = rs(42)
    # make sure these functions are same as expected outcome
    assert rng.randrange(3, 5) == rs42.randint(3, 5)
    assert rng.choice([1, 2, 3]) == rs42.choice([1, 2, 3])
    assert rng.gauss(0, 1) == rs42.normal(0, 1)
    assert rng.expovariate(1.5) == rs42.exponential(1 / 1.5)
    # shuffle returns None on both sides; this mainly checks both ran.
    assert np.all(rng.shuffle([1, 2, 3]) == rs42.shuffle([1, 2, 3]))
    assert np.all(
        rng.sample([1, 2, 3], 2) == rs42.choice([1, 2, 3], (2,), replace=False)
    )
    assert np.all(
        [rng.randint(3, 5) for _ in range(100)]
        == [rs42.randint(3, 6) for _ in range(100)]
    )
    assert rng.random() == rs42.random_sample()
def test_PythonRandomInterface_Generator():
    """PythonRandomInterface over numpy Generator mirrors random.Random's API.

    NOTE(review): each assert consumes one draw from *both* generators in
    lockstep, so statement order here is load-bearing — do not reorder.
    """
    np = pytest.importorskip("numpy")
    rng = np.random.default_rng(42)
    pri = PythonRandomInterface(np.random.default_rng(42))
    # make sure these functions are same as expected outcome
    assert pri.randrange(3, 5) == rng.integers(3, 5)
    assert pri.choice([1, 2, 3]) == rng.choice([1, 2, 3])
    assert pri.gauss(0, 1) == rng.normal(0, 1)
    assert pri.expovariate(1.5) == rng.exponential(1 / 1.5)
    # shuffle returns None on both sides; this mainly checks both ran.
    assert np.all(pri.shuffle([1, 2, 3]) == rng.shuffle([1, 2, 3]))
    assert np.all(
        pri.sample([1, 2, 3], 2) == rng.choice([1, 2, 3], (2,), replace=False)
    )
    assert np.all(
        [pri.randint(3, 5) for _ in range(100)]
        == [rng.integers(3, 6) for _ in range(100)]
    )
    assert pri.random() == rng.random()
@pytest.mark.parametrize(
    ("iterable_type", "expected"), ((list, 1), (tuple, 1), (str, "["), (set, 1))
)
def test_arbitrary_element(iterable_type, expected):
    """An element of the iterable is returned; note str([1, 2, 3])[0] == '['."""
    iterable = iterable_type([1, 2, 3])
    assert arbitrary_element(iterable) == expected
@pytest.mark.parametrize(
    "iterator", ((i for i in range(3)), iter([1, 2, 3]))  # generator
)
def test_arbitrary_element_raises(iterator):
    """Value error is raised when input is an iterator."""
    with pytest.raises(ValueError, match="from an iterator"):
        arbitrary_element(iterator)

View File

@@ -0,0 +1,38 @@
import pytest
from networkx.utils import (
powerlaw_sequence,
random_weighted_sample,
weighted_choice,
zipf_rv,
)
def test_degree_sequences():
    """Smoke test powerlaw_sequence with and without an explicit seed."""
    powerlaw_sequence(10, seed=1)
    seq = powerlaw_sequence(10)
    assert len(seq) == 10
def test_zipf_rv():
    """zipf_rv returns an int and validates alpha > 1 and xmin >= 1."""
    r = zipf_rv(2.3, xmin=2, seed=1)
    r = zipf_rv(2.3, 2, 1)
    r = zipf_rv(2.3)
    # Bug fix: the original ``assert type(r), int`` only asserted that
    # ``type(r)`` is truthy (``int`` was the assert *message*), so it could
    # never fail.  zipf_rv returns ``int(...)``, so pin that explicitly.
    assert isinstance(r, int)
    pytest.raises(ValueError, zipf_rv, 0.5)
    pytest.raises(ValueError, zipf_rv, 2, xmin=0)
def test_random_weighted_sample():
    """A sample of size len(mapping) contains every key exactly once."""
    mapping = {"a": 10, "b": 20}
    random_weighted_sample(mapping, 2, seed=1)
    sample = random_weighted_sample(mapping, 2)
    assert sorted(sample) == sorted(mapping.keys())
    # Requesting more items than exist is an error.
    pytest.raises(ValueError, random_weighted_sample, mapping, 3)
def test_random_weighted_choice():
    """A key with zero weight can never be chosen."""
    weights = {"a": 10, "b": 0}
    weighted_choice(weights, seed=1)  # seeded smoke test
    chosen = weighted_choice(weights)
    assert chosen == "a"

View File

@@ -0,0 +1,63 @@
import networkx as nx
from networkx.utils import reverse_cuthill_mckee_ordering
def test_reverse_cuthill_mckee():
    """RCM ordering of the Boost example graph is one of its two valid answers."""
    # example graph from
    # http://www.boost.org/doc/libs/1_37_0/libs/graph/example/cuthill_mckee_ordering.cpp
    edges = [
        (0, 3), (0, 5), (1, 2), (1, 4), (1, 6), (1, 9), (2, 3),
        (2, 4), (3, 5), (3, 8), (4, 6), (5, 6), (5, 7), (6, 7),
    ]
    G = nx.Graph(edges)
    ordering = list(reverse_cuthill_mckee_ordering(G))
    # Ties in the BFS give two equally valid orderings.
    expected = [
        [0, 8, 5, 7, 3, 6, 2, 4, 1, 9],
        [0, 8, 5, 7, 3, 6, 4, 2, 1, 9],
    ]
    assert ordering in expected
def test_rcm_alternate_heuristic():
    """RCM accepts a user-supplied start-node heuristic."""
    # Graph with self-loops; the ordering is driven by a heuristic that
    # always picks the node of currently smallest degree.
    edges = [
        (0, 0), (0, 4), (1, 1), (1, 2), (1, 5), (1, 7), (2, 2),
        (2, 4), (3, 3), (3, 6), (4, 4), (5, 5), (5, 7), (6, 6), (7, 7),
    ]
    G = nx.Graph(edges)
    valid_orderings = [
        [6, 3, 5, 7, 1, 2, 4, 0],
        [6, 3, 7, 5, 1, 2, 4, 0],
        [7, 5, 1, 2, 4, 0, 6, 3],
    ]

    def smallest_degree(G):
        # min over (degree, node) pairs breaks degree ties by node label.
        _, node = min((d, n) for n, d in G.degree())
        return node

    ordering = list(reverse_cuthill_mckee_ordering(G, heuristic=smallest_degree))
    assert ordering in valid_orderings

View File

@@ -0,0 +1,55 @@
import networkx as nx
def test_unionfind():
    """Unioning keys of mixed types must not raise."""
    # Fixed by: 2cddd5958689bdecdcd89b91ac9aaf6ce0e4f6b8
    # In networkx 2.x, UnionFind tolerated mixed key types, but under
    # Python 3 comparing e.g. str and int raises:
    #     TypeError: unorderable types: str() > int()
    # The union below simply must not raise.
    partition = nx.utils.UnionFind()
    partition.union(0, "a")
def test_subtree_union():
    """to_sets() reports one merged set after whole subtrees are unioned."""
    # See https://github.com/networkx/networkx/pull/3224
    # (35db1b551ee65780794a357794f521d8768d5049).
    # Test that subtree unions are handled correctly by to_sets().
    uf = nx.utils.UnionFind()
    for a, b in ((1, 2), (3, 4), (4, 5), (1, 5)):
        uf.union(a, b)
    assert list(uf.to_sets()) == [{1, 2, 3, 4, 5}]
def test_unionfind_weights():
    """Weights stay correct when many elements are unioned at once."""
    uf = nx.utils.UnionFind()
    for triple in ((1, 4, 7), (2, 5, 8), (3, 6, 9)):
        uf.union(*triple)
    uf.union(*range(1, 10))
    # All nine elements end up under one root of weight 9.
    assert uf.weights[uf[1]] == 9
def test_unbalanced_merge_weights():
    """The larger set's root becomes the root of an unbalanced merge."""
    uf = nx.utils.UnionFind()
    uf.union(1, 2, 3)
    uf.union(4, 5, 6, 7, 8, 9)
    assert uf.weights[uf[1]] == 3
    assert uf.weights[uf[4]] == 6
    heavy_root = uf[4]
    uf.union(1, 4)
    # The six-element set's root wins, and the weights are summed.
    assert uf[1] == heavy_root
    assert uf.weights[heavy_root] == 9
def test_empty_union():
    """union() with no arguments is a no-op."""
    partition = nx.utils.UnionFind((0, 1))
    partition.union()
    assert partition[0] == 0
    assert partition[1] == 1

View File

@@ -0,0 +1,106 @@
"""
Union-find data structure.
"""
from networkx.utils import groups
class UnionFind:
    """Union-find data structure.

    Each unionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following two methods:

    - X[item] returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in X, a new singleton set is
      created for it.

    - X.union(item1, item2, ...) merges the sets containing each item
      into a single larger set.  If any item is not yet part of a set
      in X, it is added to X as one of the members of the merged set.

    Union-find data structure. Based on Josiah Carlson's code,
    https://code.activestate.com/recipes/215912/
    with significant additional changes by D. Eppstein.
    http://www.ics.uci.edu/~eppstein/PADS/UnionFind.py

    """

    def __init__(self, elements=None):
        """Create a new empty union-find structure.

        If *elements* is an iterable, this structure will be initialized
        with the discrete partition on the given set of elements.

        """
        if elements is None:
            elements = ()
        # parents maps each element to its parent in the forest; roots map
        # to themselves.  weights[root] is the size of the root's set.
        self.parents = {}
        self.weights = {}
        for x in elements:
            self.weights[x] = 1
            self.parents[x] = x

    def __getitem__(self, element):
        """Find and return the name (root) of the set containing the element."""
        # NOTE: the parameter was renamed from ``object`` (which shadowed the
        # builtin); __getitem__ is only ever invoked positionally via X[item],
        # so no caller can be affected.
        # Check for a previously unknown element: it becomes its own
        # singleton set.
        if element not in self.parents:
            self.parents[element] = element
            self.weights[element] = 1
            return element

        # Find the path of elements leading to the root.
        path = []
        root = self.parents[element]
        while root != element:
            path.append(element)
            element = root
            root = self.parents[element]

        # Path compression: repoint every element on the path directly at
        # the root so later lookups are O(1)-ish.
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or unioned by this structure."""
        return iter(self.parents)

    def to_sets(self):
        """Iterates over the sets stored in this structure.

        For example::

            >>> partition = UnionFind("xyz")
            >>> sorted(map(sorted, partition.to_sets()))
            [['x'], ['y'], ['z']]
            >>> partition.union("x", "y")
            >>> sorted(map(sorted, partition.to_sets()))
            [['x', 'y'], ['z']]

        """
        # Ensure fully pruned paths so self.parents maps element -> root.
        for x in self.parents:
            _ = self[x]  # Evaluated for side-effect only
        yield from groups(self.parents).values()

    def union(self, *objects):
        """Find the sets containing the objects and merge them all.

        Union by weight: the heaviest root absorbs the others, keeping
        the forest shallow.  Calling with no arguments is a no-op.
        """
        # Find the heaviest root according to its weight.
        roots = iter(
            sorted(
                {self[x] for x in objects}, key=lambda r: self.weights[r], reverse=True
            )
        )
        try:
            root = next(roots)
        except StopIteration:
            # No objects were given: nothing to merge.
            return

        for r in roots:
            self.weights[root] += self.weights[r]
            self.parents[r] = root