This commit is contained in:
ton
2024-10-07 10:13:40 +07:00
parent aa1631742f
commit 3a7d696db6
9729 changed files with 1832837 additions and 161742 deletions

View File

@@ -0,0 +1,7 @@
# Runtime feature flags shared across skimage to avoid repeating the checks.
from .version_requirements import is_installed
import sys
import platform
# True when a matplotlib recent enough for our plotting helpers is installed.
has_mpl = is_installed("matplotlib", ">=3.3")
# True when running under WebAssembly (e.g. Pyodide/emscripten), where
# threading and some native extensions are unavailable.
is_wasm = (sys.platform == "emscripten") or (platform.machine() in ["wasm32", "wasm64"])

View File

@@ -0,0 +1,54 @@
__all__ = ['polygon_clip', 'polygon_area']
import numpy as np
from .version_requirements import require
@require("matplotlib", ">=3.3")
def polygon_clip(rp, cp, r0, c0, r1, c1):
    """Clip a polygon to the given bounding box.

    Parameters
    ----------
    rp, cp : (K,) ndarray of double
        Row and column coordinates of the polygon.
    (r0, c0), (r1, c1) : double
        Top-left and bottom-right coordinates of the bounding box.

    Returns
    -------
    r_clipped, c_clipped : (L,) ndarray of double
        Coordinates of clipped polygon.

    Notes
    -----
    Relies on Matplotlib, which exposes the Sutherland-Hodgman clipping
    algorithm as implemented in AGG 2.4.
    """
    # Lazy import keeps matplotlib an optional dependency.
    from matplotlib import path, transforms

    vertices = np.stack((rp, cp), axis=-1)
    polygon = path.Path(vertices, closed=True)
    bbox = transforms.Bbox([[r0, c0], [r1, c1]])
    clipped = polygon.clip_to_bbox(bbox).to_polygons()[0]
    return clipped[:, 0], clipped[:, 1]
def polygon_area(pr, pc):
    """Compute the area of a polygon.

    Uses the shoelace formula over the given (closed) vertex list.

    Parameters
    ----------
    pr, pc : (K,) array of float
        Polygon row and column coordinates.

    Returns
    -------
    a : float
        Area of the polygon.
    """
    pr, pc = np.asarray(pr), np.asarray(pc)
    cross_terms = pc[:-1] * pr[1:] - pc[1:] * pr[:-1]
    return 0.5 * abs(cross_terms.sum())

View File

@@ -0,0 +1,28 @@
from tempfile import NamedTemporaryFile
from contextlib import contextmanager
import os
@contextmanager
def temporary_file(suffix=''):
    """Yield a writeable temporary filename that is deleted on context exit.

    The file is removed even if the body of the ``with`` block raises
    (the previous implementation leaked the file on exceptions).

    Parameters
    ----------
    suffix : string, optional
        The suffix for the file.

    Examples
    --------
    >>> import numpy as np
    >>> from skimage import io
    >>> with temporary_file('.tif') as tempfile:
    ...     im = np.arange(25, dtype=np.uint8).reshape((5, 5))
    ...     io.imsave(tempfile, im)
    ...     assert np.all(io.imread(tempfile) == im)
    """
    # Create and close the file immediately so the caller gets a name that
    # can be re-opened for writing on all platforms.
    with NamedTemporaryFile(suffix=suffix, delete=False) as tempfile_stream:
        tempfile = tempfile_stream.name
    try:
        yield tempfile
    finally:
        # Clean up even when the caller's block raised.
        os.remove(tempfile)

View File

@@ -0,0 +1,149 @@
from contextlib import contextmanager
import sys
import warnings
import re
import functools
import os
__all__ = ['all_warnings', 'expected_warnings', 'warn']
# A version of `warnings.warn` with a default stacklevel of 2.
# `functools.partial` is used (rather than a wrapper function) so as not to
# increase the call stack accidentally.
warn = functools.partial(warnings.warn, stacklevel=2)
@contextmanager
def all_warnings():
    """
    Context for use in testing to ensure that all warnings are raised.

    Python caches emitted warnings per module in ``__warningregistry__``,
    which can silently suppress repeated warnings in later tests.  This
    context manager wipes those registries and installs an "always" filter
    so every warning raised inside the block is recorded.

    Examples
    --------
    >>> import warnings
    >>> def foo():
    ...     warnings.warn(RuntimeWarning("bar"), stacklevel=2)

    We raise the warning once, while the warning filter is set to "once".
    Hereafter, the warning is invisible, even with custom filters:

    >>> with warnings.catch_warnings():
    ...     warnings.simplefilter('once')
    ...     foo()  # doctest: +SKIP

    We can now run ``foo()`` without a warning being raised:

    >>> from numpy.testing import assert_warns
    >>> foo()  # doctest: +SKIP

    To catch the warning, we call in the help of ``all_warnings``:

    >>> with all_warnings():
    ...     assert_warns(RuntimeWarning, foo)
    """
    # _warnings.py is on the critical import path, so import `inspect`
    # lazily -- this is a testing-only helper.
    import inspect

    # Whenever a warning is triggered, Python adds a __warningregistry__
    # member to the *calling* module.  Clear the registries of every parent
    # calling frame (needed for the doctests above) and of every imported
    # module (needed for the skimage test suite).
    current = inspect.currentframe()
    if current:
        for frame_info in inspect.getouterframes(current):
            frame_info[0].f_locals['__warningregistry__'] = {}
    del current

    for _, module in list(sys.modules.items()):
        try:
            module.__warningregistry__.clear()
        except AttributeError:
            # Module has no registry yet; nothing to clear.
            pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        yield caught
@contextmanager
def expected_warnings(matching):
    r"""Context for use in testing to catch known warnings matching regexes

    Parameters
    ----------
    matching : None or a list of strings or compiled regexes
        Regexes for the desired warning to catch
        If matching is None, this behaves as a no-op.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng()
    >>> image = rng.integers(0, 2**16, size=(100, 100), dtype=np.uint16)
    >>> # rank filters are slow when bit-depth exceeds 10 bits
    >>> from skimage import filters
    >>> with expected_warnings(['Bad rank filter performance']):
    ...     median_filtered = filters.rank.median(image)

    Notes
    -----
    Uses `all_warnings` to ensure all warnings are raised.
    Upon exiting, it checks the recorded warnings for the desired matching
    pattern(s).
    Raises a ValueError if any match was not found or an unexpected
    warning was raised.
    Allows for three types of behaviors: `and`, `or`, and `optional` matches.
    This is done to accommodate different build environments or loop conditions
    that may produce different warnings.  The behaviors can be combined.
    If you pass multiple patterns, you get an orderless `and`, where all of the
    warnings must be raised.
    If you use the `|` operator in a pattern, you can catch one of several
    warnings.
    Finally, you can use `|\A\Z` in a pattern to signify it as optional.
    """
    if isinstance(matching, str):
        raise ValueError(
            '``matching`` should be a list of strings and not ' 'a string itself.'
        )

    # Special case for disabling the context manager
    if matching is None:
        yield None
        return

    # Strictness is configurable via the environment: 'true'/'false' or any
    # string accepted by int() (the default '1' means strict).
    strict_warnings = os.environ.get('SKIMAGE_TEST_STRICT_WARNINGS', '1')
    if strict_warnings.lower() == 'true':
        strict_warnings = True
    elif strict_warnings.lower() == 'false':
        strict_warnings = False
    else:
        strict_warnings = bool(int(strict_warnings))

    with all_warnings() as w:
        # enter context
        yield w
        # Exited user context; check the recorded warnings.
        # Allow users to provide None entries, but filter them on a copy so
        # the caller's list is not mutated (previously items were removed
        # in place).
        patterns = [m for m in matching if m is not None]
        # Patterns marked optional with `|\A\Z` need not be raised.
        remaining = [m for m in patterns if r'\A\Z' not in m.split('|')]
        for recorded in w:
            # NOTE: the loop variable was previously named `warn`, shadowing
            # the module-level `warn` helper; renamed to avoid confusion.
            found = False
            for pattern in patterns:
                if re.search(pattern, str(recorded.message)) is not None:
                    found = True
                    if pattern in remaining:
                        remaining.remove(pattern)
            if strict_warnings and not found:
                raise ValueError(f'Unexpected warning: {str(recorded.message)}')
        if strict_warnings and (len(remaining) > 0):
            newline = "\n"
            msg = f"No warning raised matching:{newline}{newline.join(remaining)}"
            raise ValueError(msg)

View File

@@ -0,0 +1,30 @@
"""Compatibility helpers for dependencies."""
from packaging.version import parse
import numpy as np
import scipy as sp
__all__ = [
    "NP_COPY_IF_NEEDED",
    "SCIPY_CG_TOL_PARAM_NAME",
]
# True when running against NumPy older than 2.0.0 (dev pre-releases of 2.0
# count as 2.0, hence the '.dev0' suffix in the comparison).
NUMPY_LT_2_0_0 = parse(np.__version__) < parse('2.0.0.dev0')
# With NumPy 2.0.0, `copy=False` now raises a ValueError if the copy cannot be
# made. The previous behavior to only copy if needed is provided with `copy=None`.
# During the transition period, use this symbol instead.
# Remove once NumPy 2.0.0 is the minimal required version.
# https://numpy.org/devdocs/release/2.0.0-notes.html#new-copy-keyword-meaning-for-array-and-asarray-constructors
# https://github.com/numpy/numpy/pull/25168
NP_COPY_IF_NEEDED = False if NUMPY_LT_2_0_0 else None
# True when running against SciPy older than 1.12.
SCIPY_LT_1_12 = parse(sp.__version__) < parse('1.12')
# Starting in SciPy v1.12, 'scipy.sparse.linalg.cg' keyword argument `tol` is
# deprecated in favor of `rtol`.
SCIPY_CG_TOL_PARAM_NAME = "tol" if SCIPY_LT_1_12 else "rtol"

View File

@@ -0,0 +1,125 @@
import numpy as np
from scipy.spatial import cKDTree, distance
def _ensure_spacing(coord, spacing, p_norm, max_out):
"""Returns a subset of coord where a minimum spacing is guaranteed.
Parameters
----------
coord : ndarray
The coordinates of the considered points.
spacing : float
the maximum allowed spacing between the points.
p_norm : float
Which Minkowski p-norm to use. Should be in the range [1, inf].
A finite large p may cause a ValueError if overflow can occur.
``inf`` corresponds to the Chebyshev distance and 2 to the
Euclidean distance.
max_out: int
If not None, at most the first ``max_out`` candidates are
returned.
Returns
-------
output : ndarray
A subset of coord where a minimum spacing is guaranteed.
"""
# Use KDtree to find the peaks that are too close to each other
tree = cKDTree(coord)
indices = tree.query_ball_point(coord, r=spacing, p=p_norm)
rejected_peaks_indices = set()
naccepted = 0
for idx, candidates in enumerate(indices):
if idx not in rejected_peaks_indices:
# keep current point and the points at exactly spacing from it
candidates.remove(idx)
dist = distance.cdist(
[coord[idx]], coord[candidates], distance.minkowski, p=p_norm
).reshape(-1)
candidates = [c for c, d in zip(candidates, dist) if d < spacing]
# candidates.remove(keep)
rejected_peaks_indices.update(candidates)
naccepted += 1
if max_out is not None and naccepted >= max_out:
break
# Remove the peaks that are too close to each other
output = np.delete(coord, tuple(rejected_peaks_indices), axis=0)
if max_out is not None:
output = output[:max_out]
return output
def ensure_spacing(
    coords,
    spacing=1,
    p_norm=np.inf,
    min_split_size=50,
    max_out=None,
    *,
    max_split_size=2000,
):
    """Returns a subset of coord where a minimum spacing is guaranteed.

    Parameters
    ----------
    coords : array_like
        The coordinates of the considered points.
    spacing : float
        the maximum allowed spacing between the points.
    p_norm : float
        Which Minkowski p-norm to use. Should be in the range [1, inf].
        A finite large p may cause a ValueError if overflow can occur.
        ``inf`` corresponds to the Chebyshev distance and 2 to the
        Euclidean distance.
    min_split_size : int
        Minimum split size used to process ``coords`` by batch to save
        memory. If None, the memory saving strategy is not applied.
    max_out : int
        If not None, only the first ``max_out`` candidates are returned.
    max_split_size : int
        Maximum split size used to process ``coords`` by batch to save
        memory. This number was decided by profiling with a large number
        of points. Too small a number results in too much looping in
        Python instead of C, slowing down the process, while too large
        a number results in large memory allocations, slowdowns, and,
        potentially, in the process being killed -- see gh-6010. See
        benchmark results `here
        <https://github.com/scikit-image/scikit-image/pull/6035#discussion_r751518691>`_.

    Returns
    -------
    output : array_like
        A subset of coord where a minimum spacing is guaranteed.
    """
    if not len(coords):
        # Nothing to filter; return the (possibly empty list) input as-is.
        return coords

    coords = np.atleast_2d(coords)
    if min_split_size is None:
        batches = [coords]
    else:
        # Batch boundaries grow geometrically (capped at max_split_size) so
        # an early exit via ``max_out`` only processes a small prefix.
        boundaries = [min_split_size]
        step = min_split_size
        while len(coords) - boundaries[-1] > max_split_size:
            step *= 2
            boundaries.append(boundaries[-1] + min(step, max_split_size))
        batches = np.array_split(coords, boundaries)

    output = np.zeros((0, coords.shape[1]), dtype=coords.dtype)
    for batch in batches:
        # Re-filter previously accepted points together with the new batch so
        # the spacing guarantee holds across batch boundaries.
        output = _ensure_spacing(np.vstack([output, batch]), spacing, p_norm, max_out)
        if max_out is not None and len(output) >= max_out:
            break

    return output

View File

@@ -0,0 +1,73 @@
import numpy as np
# Define classes of supported dtypes and Python scalar types
# Variables ending in `_dtypes` only contain numpy.dtypes of the respective
# class; variables ending in `_types` additionally include Python scalar types.
signed_integer_dtypes = {np.int8, np.int16, np.int32, np.int64}
signed_integer_types = signed_integer_dtypes | {int}
unsigned_integer_dtypes = {np.uint8, np.uint16, np.uint32, np.uint64}
integer_dtypes = signed_integer_dtypes | unsigned_integer_dtypes
# There is no dedicated Python scalar type for unsigned integers, so the
# unsigned contribution to `integer_types` is NumPy dtypes only.
integer_types = signed_integer_types | unsigned_integer_dtypes
floating_dtypes = {np.float16, np.float32, np.float64}
floating_types = floating_dtypes | {float}
complex_dtypes = {np.complex64, np.complex128}
complex_types = complex_dtypes | {complex}
inexact_dtypes = floating_dtypes | complex_dtypes
inexact_types = floating_types | complex_types
bool_types = {np.dtype(bool), bool}
numeric_dtypes = integer_dtypes | inexact_dtypes | {np.bool_}
numeric_types = integer_types | inexact_types | bool_types
def numeric_dtype_min_max(dtype):
    """Return minimum and maximum representable value for a given dtype.

    A convenient wrapper around `numpy.finfo` and `numpy.iinfo` that
    additionally supports numpy.bool as well.

    Parameters
    ----------
    dtype : numpy.dtype
        The dtype. Tries to convert Python "types" such as int or float, to
        the corresponding NumPy dtype.

    Returns
    -------
    min, max : number
        Minimum and maximum of the given `dtype`. These scalars are themselves
        of the given `dtype`.

    Examples
    --------
    >>> import numpy as np
    >>> numeric_dtype_min_max(np.uint8)
    (0, 255)
    >>> numeric_dtype_min_max(bool)
    (False, True)
    >>> numeric_dtype_min_max(np.float64)
    (-1.7976931348623157e+308, 1.7976931348623157e+308)
    >>> numeric_dtype_min_max(int)
    (-9223372036854775808, 9223372036854775807)
    """
    # Normalize Python scalar types (int, float, bool, ...) to NumPy dtypes.
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        bounds = np.iinfo(dtype)
        # iinfo returns plain Python ints; cast to the requested dtype.
        return dtype.type(bounds.min), dtype.type(bounds.max)
    if np.issubdtype(dtype, np.inexact):
        # finfo's min/max are already scalars of the dtype.
        bounds = np.finfo(dtype)
        return bounds.min, bounds.max
    if np.issubdtype(dtype, np.dtype(bool)):
        return dtype.type(False), dtype.type(True)
    raise ValueError(f"unsupported dtype {dtype!r}")

View File

@@ -0,0 +1,47 @@
/* A fast approximation of the exponential function.
* Reference [1]: https://schraudolph.org/pubs/Schraudolph99.pdf
* Reference [2]: https://doi.org/10.1162/089976600300015033
* Additional improvements by Leonid Bloch. */
#include <stdint.h>
/* use just EXP_A = 1512775 for integer version, to avoid FP calculations */
#define EXP_A (1512775.3951951856938) /* 2^20/ln2 */
/* EXP_BC combines the IEEE-754 exponent bias term with a correction
 * constant; the three alternatives below trade off different error
 * criteria (see reference [1] above). */
/* For min. RMS error */
#define EXP_BC 1072632447 /* 1023*2^20 - 60801 */
/* For min. max. relative error */
/* #define EXP_BC 1072647449 */ /* 1023*2^20 - 45799 */
/* For min. mean relative error */
/* #define EXP_BC 1072625005 */ /* 1023*2^20 - 68243 */
/* Approximate exp(y) via Schraudolph's method: write EXP_A*y + EXP_BC into
 * the high 32 bits of a double so its IEEE-754 exponent/mantissa encode a
 * value close to exp(y).  See references [1] and [2] above. */
__inline double _fast_exp (double y)
{
  union
  {
    double d;
    struct { int32_t i, j; } n;
    char t[8];
  } _eco;
  /* Runtime endianness probe: store 1 in the first int and inspect the
   * union's first byte. */
  _eco.n.i = 1;
  switch(_eco.t[0]) {
    case 1:
      /* Little endian: the high half of the double is the second int (j). */
      _eco.n.j = (int32_t)(EXP_A*(y)) + EXP_BC;
      _eco.n.i = 0;
      break;
    case 0:
      /* Big endian: the high half of the double is the first int (i). */
      _eco.n.i = (int32_t)(EXP_A*(y)) + EXP_BC;
      _eco.n.j = 0;
      break;
  }
  return _eco.d;
}
/* Single-precision wrapper: compute in double precision and narrow. */
__inline float _fast_expf (float y)
{
  return (float)_fast_exp((double)y);
}

View File

@@ -0,0 +1,142 @@
"""Filters used across multiple skimage submodules.
These are defined here to avoid circular imports.
The unit tests remain under skimage/filters/tests/
"""
from collections.abc import Iterable
import numpy as np
from scipy import ndimage as ndi
from .._shared.utils import (
_supported_float_type,
convert_to_float,
deprecate_parameter,
DEPRECATED,
)
@deprecate_parameter(
    "output", new_name="out", start_version="0.23", stop_version="0.25"
)
def gaussian(
    image,
    sigma=1,
    output=DEPRECATED,
    mode='nearest',
    cval=0,
    preserve_range=False,
    truncate=4.0,
    *,
    channel_axis=None,
    out=None,
):
    """Multi-dimensional Gaussian filter.

    Parameters
    ----------
    image : ndarray
        Input image (grayscale or color) to filter.
    sigma : scalar or sequence of scalars, optional
        Standard deviation for Gaussian kernel. The standard
        deviations of the Gaussian filter are given for each axis as a
        sequence, or as a single number, in which case it is equal for
        all axes.
    mode : {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}, optional
        The ``mode`` parameter determines how the array borders are
        handled, where ``cval`` is the value when mode is equal to
        'constant'. Default is 'nearest'.
    cval : scalar, optional
        Value to fill past edges of input if ``mode`` is 'constant'. Default
        is 0.0
    preserve_range : bool, optional
        If True, keep the original range of values. Otherwise, the input
        ``image`` is converted according to the conventions of ``img_as_float``
        (Normalized first to values [-1.0 ; 1.0] or [0 ; 1.0] depending on
        dtype of input)
        For more information, see:
        https://scikit-image.org/docs/dev/user_guide/data_types.html
    truncate : float, optional
        Truncate the filter at this many standard deviations.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           `channel_axis` was added in 0.19.
    out : ndarray, optional
        If given, the filtered image will be stored in this array.

        .. versionadded:: 0.23
           `out` was added in 0.23.

    Returns
    -------
    filtered_image : ndarray
        the filtered array

    Notes
    -----
    This function is a wrapper around :func:`scipy.ndimage.gaussian_filter`.

    Integer arrays are converted to float.

    `out` should be of floating-point data type since `gaussian` converts the
    input `image` to float. If `out` is not provided, another array
    will be allocated and returned as the result.

    The multi-dimensional filter is implemented as a sequence of
    one-dimensional convolution filters. The intermediate arrays are
    stored in the same data type as the output. Therefore, for output
    types with a limited precision, the results may be imprecise
    because intermediate results may be stored with insufficient
    precision.

    Examples
    --------
    >>> import skimage as ski
    >>> a = np.zeros((3, 3))
    >>> a[1, 1] = 1
    >>> a
    array([[0., 0., 0.],
           [0., 1., 0.],
           [0., 0., 0.]])
    >>> ski.filters.gaussian(a, sigma=0.4)  # mild smoothing
    array([[0.00163116, 0.03712502, 0.00163116],
           [0.03712502, 0.84496158, 0.03712502],
           [0.00163116, 0.03712502, 0.00163116]])
    >>> ski.filters.gaussian(a, sigma=1)  # more smoothing
    array([[0.05855018, 0.09653293, 0.05855018],
           [0.09653293, 0.15915589, 0.09653293],
           [0.05855018, 0.09653293, 0.05855018]])
    >>> # Several modes are possible for handling boundaries
    >>> ski.filters.gaussian(a, sigma=1, mode='reflect')
    array([[0.08767308, 0.12075024, 0.08767308],
           [0.12075024, 0.16630671, 0.12075024],
           [0.08767308, 0.12075024, 0.08767308]])
    >>> # For RGB images, each is filtered separately
    >>> image = ski.data.astronaut()
    >>> filtered_img = ski.filters.gaussian(image, sigma=1, channel_axis=-1)
    """
    # NOTE: `output` is only a deprecated alias; the decorator above remaps
    # it to `out`, so the body never reads `output` directly.
    if np.any(np.asarray(sigma) < 0.0):
        raise ValueError("Sigma values less than zero are not valid")
    if channel_axis is not None:
        # do not filter across channels
        if not isinstance(sigma, Iterable):
            sigma = [sigma] * (image.ndim - 1)
        if len(sigma) == image.ndim - 1:
            # Insert sigma=0 at the channel position so channels stay unmixed.
            sigma = list(sigma)
            sigma.insert(channel_axis % image.ndim, 0)
    image = convert_to_float(image, preserve_range)
    float_dtype = _supported_float_type(image.dtype)
    image = image.astype(float_dtype, copy=False)
    # The input is float after conversion above, so an integer `out` would
    # silently truncate results; reject it explicitly.
    if (out is not None) and (not np.issubdtype(out.dtype, np.floating)):
        raise ValueError(f"dtype of `out` must be float; got {out.dtype!r}.")
    return ndi.gaussian_filter(
        image, sigma, output=out, mode=mode, cval=cval, truncate=truncate
    )

View File

@@ -0,0 +1,131 @@
import os
import sys
def _show_skimage_info():
    """Print the installed scikit-image version (used by the test runner)."""
    import skimage

    print(f"skimage version {skimage.__version__}")
class PytestTester:
    """
    Pytest test runner.

    This class is made available in ``skimage._shared.testing``, and a test
    function is typically added to a package's __init__.py like so::

        from skimage._shared.testing import PytestTester
        test = PytestTester(__name__)
        del PytestTester

    Calling this test function finds and runs all tests associated with the
    module and all its sub-modules.

    Attributes
    ----------
    module_name : str
        Full path to the package to test.

    Parameters
    ----------
    module_name : module name
        The name of the module to test.
    """

    def __init__(self, module_name):
        # Dotted name of the package whose tests will be collected.
        self.module_name = module_name

    def __call__(
        self,
        label='fast',
        verbose=1,
        extra_argv=None,
        doctests=False,
        coverage=False,
        durations=-1,
        tests=None,
    ):
        """
        Run tests for module using pytest.

        Parameters
        ----------
        label : {'fast', 'full'}, optional
            Identifies the tests to run. When set to 'fast', tests decorated
            with `pytest.mark.slow` are skipped, when 'full', the slow marker
            is ignored.
        verbose : int, optional
            Verbosity value for test outputs, in the range 1-3. Default is 1.
        extra_argv : list, optional
            List with any extra arguments to pass to pytests.
        doctests : bool, optional
            .. note:: Not supported
        coverage : bool, optional
            If True, report coverage of scikit-image code. Default is False.
            Requires installation of (pip) pytest-cov.
        durations : int, optional
            If < 0, do nothing, If 0, report time of all tests, if > 0,
            report the time of the slowest `timer` tests. Default is -1.
        tests : test or list of tests
            Tests to be executed with pytest '--pyargs'

        Returns
        -------
        result : bool
            Return True on success, false otherwise.
        """
        import pytest

        target = sys.modules[self.module_name]
        target_path = os.path.abspath(target.__path__[0])

        # "-l" shows local variables in tracebacks; "-q" offsets the
        # verbosity so that the default (verbose=1) stays quiet -- a later
        # "-v" cancels it.
        args = ["-l", "-q"]

        # Filter out annoying import messages, in both develop and release
        # mode.
        args += [
            "-W ignore:Not importing directory",
            "-W ignore:numpy.dtype size changed",
            "-W ignore:numpy.ufunc size changed",
        ]

        if doctests:
            raise ValueError("Doctests not supported")

        if extra_argv:
            args += list(extra_argv)
        if verbose > 1:
            args += ["-" + "v" * (verbose - 1)]
        if coverage:
            args += ["--cov=" + target_path]

        if label == "fast":
            args += ["-m", "not slow"]
        elif label != "full":
            args += ["-m", label]
        if durations >= 0:
            args += [f"--durations={durations}"]

        if tests is None:
            tests = [self.module_name]
        args += ["--pyargs"] + list(tests)

        # Run the tests.
        _show_skimage_info()
        try:
            exit_code = pytest.main(args)
        except SystemExit as exc:
            exit_code = exc.code
        return exit_code == 0

View File

@@ -0,0 +1,303 @@
"""
Testing utilities.
"""
import os
import platform
import re
import struct
import sys
import functools
import inspect
from tempfile import NamedTemporaryFile
import numpy as np
from numpy import testing
from numpy.testing import (
TestCase,
assert_,
assert_warns,
assert_no_warnings,
assert_equal,
assert_almost_equal,
assert_array_equal,
assert_allclose,
assert_array_almost_equal,
assert_array_almost_equal_nulp,
assert_array_less,
)
from .. import data, io
from ..data._fetchers import _fetch
from ..util import img_as_uint, img_as_float, img_as_int, img_as_ubyte
from ._warnings import expected_warnings
from ._dependency_checks import is_wasm
import pytest
skipif = pytest.mark.skipif
xfail = pytest.mark.xfail
parametrize = pytest.mark.parametrize
raises = pytest.raises
fixture = pytest.fixture
# Matches a doctest line followed by "# skip if <expression>" markup;
# groups are (code, whitespace, expression).  Used by doctest_skip_parser.
SKIP_RE = re.compile(r"(\s*>>>.*?)(\s*)#\s*skip\s+if\s+(.*)$")
# true if python is running in 32bit mode
# Calculate the size of a void * pointer in bits
# https://docs.python.org/3/library/struct.html
arch32 = struct.calcsize("P") * 8 == 32
def assert_less(a, b, msg=None):
    """Assert that ``a < b``, appending ``msg`` to the failure message."""
    details = f"{a!r} is not lower than {b!r}"
    if msg is not None:
        details = details + ": " + msg
    assert a < b, details
def assert_greater(a, b, msg=None):
    """Assert that ``a > b``, appending ``msg`` to the failure message."""
    details = f"{a!r} is not greater than {b!r}"
    if msg is not None:
        details = details + ": " + msg
    assert a > b, details
def doctest_skip_parser(func):
    """Decorator replaces custom skip test markup in doctests

    Say a function has a docstring::

        >>> something, HAVE_AMODULE, HAVE_BMODULE = 0, False, False
        >>> something # skip if not HAVE_AMODULE
        0
        >>> something # skip if HAVE_BMODULE
        0

    This decorator will evaluate the expression after ``skip if``.  If this
    evaluates to True, then the comment is replaced by ``# doctest: +SKIP``.
    If False, then the comment is just removed.  The expression is evaluated
    in the ``globals`` scope of `func`.

    For example, if the module global ``HAVE_AMODULE`` is False, and module
    global ``HAVE_BMODULE`` is False, the returned function will have
    docstring::

        >>> something # doctest: +SKIP
        >>> something + else # doctest: +SKIP
        >>> something # doctest: +SKIP
    """
    rewritten = []
    for line in func.__doc__.split('\n'):
        parsed = SKIP_RE.match(line)
        if parsed is None:
            # Not a "# skip if" line; keep it untouched.
            rewritten.append(line)
            continue
        code, space, expr = parsed.groups()
        try:
            # Used as a function decorator: evaluate in the function globals.
            if eval(expr, func.__globals__):
                code = code + space + "# doctest: +SKIP"
        except AttributeError:
            # Used as a class decorator: classes have no __globals__, so
            # fall back to the __init__ method's globals.
            if eval(expr, func.__init__.__globals__):
                code = code + space + "# doctest: +SKIP"
        rewritten.append(code)
    func.__doc__ = "\n".join(rewritten)
    return func
def roundtrip(image, plugin, suffix):
    """Save and read an image using a specified plugin"""
    if '.' not in suffix:
        suffix = '.' + suffix
    # Create and close the temp file first so the plugin can re-open it.
    temp_file = NamedTemporaryFile(suffix=suffix, delete=False)
    fname = temp_file.name
    temp_file.close()
    io.imsave(fname, image, plugin=plugin)
    restored = io.imread(fname, plugin=plugin)
    # Best-effort cleanup; failure to delete must not fail the test.
    try:
        os.remove(fname)
    except Exception:
        pass
    return restored
def color_check(plugin, fmt='png'):
    """Check roundtrip behavior for color images.

    All major input types should be handled as ubytes and read
    back correctly.

    Parameters
    ----------
    plugin : str
        Name of the I/O plugin passed through to ``roundtrip``.
    fmt : str, optional
        File format/suffix to save with.
    """
    # ubyte image must round-trip exactly.
    img = img_as_ubyte(data.chelsea())
    r1 = roundtrip(img, plugin, fmt)
    testing.assert_allclose(img, r1)
    # Boolean image: compare after casting the result back to bool.
    img2 = img > 128
    r2 = roundtrip(img2, plugin, fmt)
    testing.assert_allclose(img2, r2.astype(bool))
    # Float image: compared against the original ubyte image, i.e. the
    # plugin is expected to store floats as ubyte for this format.
    img3 = img_as_float(img)
    r3 = roundtrip(img3, plugin, fmt)
    testing.assert_allclose(r3, img)
    # Signed 16-bit image: TIFF preserves the dtype (the -100 offset checks
    # that out-of-ubyte-range values survive); other formats are compared
    # after ubyte conversion.
    img4 = img_as_int(img)
    if fmt.lower() in (('tif', 'tiff')):
        img4 -= 100
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img4)
    else:
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img_as_ubyte(img4))
    # Unsigned 16-bit image: also compared against the original ubyte image.
    img5 = img_as_uint(img)
    r5 = roundtrip(img5, plugin, fmt)
    testing.assert_allclose(r5, img)
def mono_check(plugin, fmt='png'):
    """Check the roundtrip behavior for images that support most types.

    All major input types should be handled.

    Parameters
    ----------
    plugin : str
        Name of the I/O plugin passed through to ``roundtrip``.
    fmt : str, optional
        File format/suffix to save with.
    """
    # ubyte image must round-trip exactly.
    img = img_as_ubyte(data.moon())
    r1 = roundtrip(img, plugin, fmt)
    testing.assert_allclose(img, r1)
    # Boolean image: compare after casting the result back to bool.
    img2 = img > 128
    r2 = roundtrip(img2, plugin, fmt)
    testing.assert_allclose(img2, r2.astype(bool))
    # Float image: some plugins keep floats, others store uint16 --
    # accept either by branching on the returned dtype kind.
    img3 = img_as_float(img)
    r3 = roundtrip(img3, plugin, fmt)
    if r3.dtype.kind == 'f':
        testing.assert_allclose(img3, r3)
    else:
        testing.assert_allclose(r3, img_as_uint(img))
    # Signed 16-bit image: TIFF preserves the dtype (the -100 offset checks
    # that negative values survive); other formats are compared after uint
    # conversion.
    img4 = img_as_int(img)
    if fmt.lower() in (('tif', 'tiff')):
        img4 -= 100
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img4)
    else:
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img_as_uint(img4))
    # Unsigned 16-bit image must round-trip exactly.
    img5 = img_as_uint(img)
    r5 = roundtrip(img5, plugin, fmt)
    testing.assert_allclose(r5, img5)
def fetch(data_filename):
    """Attempt to fetch data, but if unavailable, skip the tests.

    Parameters
    ----------
    data_filename : str
        Name of the file to fetch, forwarded to ``_fetch`` from
        ``skimage.data._fetchers`` (presumably a registry-relative path --
        verify against that module).
    """
    try:
        return _fetch(data_filename)
    except (ConnectionError, ModuleNotFoundError):
        # Offline, or the optional download dependency is missing: skip the
        # whole test module rather than erroring out.
        pytest.skip(f'Unable to download {data_filename}', allow_module_level=True)
# Ref: about the lack of threading support in WASM, please see
# https://github.com/pyodide/pyodide/issues/237
def run_in_parallel(num_threads=2, warnings_matching=None):
    """Decorator to run the same function multiple times in parallel.

    This decorator is useful to ensure that separate threads execute
    concurrently and correctly while releasing the GIL.

    It is currently skipped when running on WASM-based platforms, as
    the threading module is not supported.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.
    warnings_matching : list or None
        This parameter is passed on to `expected_warnings` so as not to have
        race conditions with the warnings filters. A single
        `expected_warnings` context manager is used for all threads.
        If None, then no warnings are checked.
    """
    assert num_threads > 0

    def wrapper(func):
        if is_wasm:
            # Threading isn't supported on WASM, return early
            return func

        import threading

        @functools.wraps(func)
        def inner(*args, **kwargs):
            # One shared warnings context for all threads avoids races on
            # the process-global warnings filters.
            with expected_warnings(warnings_matching):
                workers = [
                    threading.Thread(target=func, args=args, kwargs=kwargs)
                    for _ in range(num_threads - 1)
                ]
                for worker in workers:
                    worker.start()
                # The final invocation runs in the calling thread.
                func(*args, **kwargs)
                for worker in workers:
                    worker.join()

        return inner

    return wrapper
def assert_stacklevel(warnings, *, offset=-1):
    """Assert correct stacklevel of captured warnings.

    When scikit-image raises warnings, the stacklevel should ideally be set
    so that the origin of the warnings will point to the public function
    that was called by the user and not necessarily the very place where the
    warnings were emitted (which may be inside of some internal function).
    This utility function helps with checking that
    the stacklevel was set correctly on warnings captured by `pytest.warns`.

    Parameters
    ----------
    warnings : collections.abc.Iterable[warning.WarningMessage]
        Warnings that were captured by `pytest.warns`.
    offset : int, optional
        Offset from the line this function is called to the line were the
        warning is supposed to originate from. For multiline calls, the
        first line is relevant. Defaults to -1 which corresponds to the line
        right above the one where this function is called.

    Raises
    ------
    AssertionError
        If a warning in `warnings` does not match the expected line number or
        file name.

    Examples
    --------
    >>> def test_something():
    ...     with pytest.warns(UserWarning, match="some message") as record:
    ...         something_raising_a_warning()
    ...     assert_stacklevel(record)
    ...
    >>> def test_another_thing():
    ...     with pytest.warns(UserWarning, match="some message") as record:
    ...         iam_raising_many_warnings(
    ...             "A long argument that forces the call to wrap."
    ...         )
    ...     assert_stacklevel(record, offset=-3)
    """
    frame = inspect.stack()[1].frame  # 0 is current frame, 1 is outer frame
    line_number = frame.f_lineno + offset
    filename = frame.f_code.co_filename
    # Compare against the caller's actual file and line.  (Previously the
    # filename was computed but a literal "(unknown)" placeholder was
    # compared instead, so the assertion could never pass.)
    expected = f"{filename}:{line_number}"
    for warning in warnings:
        actual = f"{warning.filename}:{warning.lineno}"
        assert actual == expected, f"{actual} != {expected}"

View File

@@ -0,0 +1,91 @@
import time
import numpy as np
import pytest
from scipy.spatial.distance import pdist, minkowski
from skimage._shared.coord import ensure_spacing
@pytest.mark.parametrize("p", [1, 2, np.inf])
@pytest.mark.parametrize("size", [30, 50, None])
def test_ensure_spacing_trivial(p, size):
    # Cases where the input must be returned unchanged.
    # --- Empty input
    assert ensure_spacing([], p_norm=p) == []
    # --- A unique point
    coord = np.random.randn(1, 2)
    assert np.array_equal(coord, ensure_spacing(coord, p_norm=p, min_split_size=size))
    # --- Verified spacing
    coord = np.random.randn(100, 2)
    # --- 0 spacing
    assert np.array_equal(
        coord, ensure_spacing(coord, spacing=0, p_norm=p, min_split_size=size)
    )
    # Spacing is chosen to be half the minimum pairwise distance, so no
    # point is closer than ``spacing`` to another and all must be kept.
    spacing = pdist(coord, metric=minkowski, p=p).min() * 0.5
    out = ensure_spacing(coord, spacing=spacing, p_norm=p, min_split_size=size)
    assert np.array_equal(coord, out)
@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("size", [2, 10, None])
def test_ensure_spacing_nD(ndim, size):
    # Five identical nD points: only one may survive the spacing filter.
    coord = np.ones((5, ndim))
    expected = np.ones((1, ndim))
    assert np.array_equal(ensure_spacing(coord, min_split_size=size), expected)
@pytest.mark.parametrize("p", [1, 2, np.inf])
@pytest.mark.parametrize("size", [50, 100, None])
def test_ensure_spacing_batch_processing(p, size):
    # Batched processing must give the same result as the unbatched run.
    coord = np.random.randn(100, 2)
    # --- Use the median pairwise distance between the points as spacing
    spacing = np.median(pdist(coord, metric=minkowski, p=p))
    expected = ensure_spacing(coord, spacing=spacing, p_norm=p)
    assert np.array_equal(
        ensure_spacing(coord, spacing=spacing, p_norm=p, min_split_size=size), expected
    )
def test_max_batch_size():
    """Small batches are slow, large batches -> large allocations -> also slow.

    https://github.com/scikit-image/scikit-image/pull/6035#discussion_r751518691
    """
    coords = np.random.randint(low=0, high=1848, size=(40000, 2))
    # Time the default (capped) batch size ...
    tstart = time.time()
    ensure_spacing(coords, spacing=100, min_split_size=50, max_split_size=2000)
    dur1 = time.time() - tstart
    # ... against a 10x larger cap.
    tstart = time.time()
    ensure_spacing(coords, spacing=100, min_split_size=50, max_split_size=20000)
    dur2 = time.time() - tstart
    # Originally checked dur1 < dur2 to assert that the default batch size was
    # faster than a much larger batch size. However, on rare occasion a CI test
    # case would fail with dur1 ~5% larger than dur2. To be more robust to
    # variable load or differences across architectures, we relax this here.
    assert dur1 < 1.33 * dur2
@pytest.mark.parametrize("p", [1, 2, np.inf])
@pytest.mark.parametrize("size", [30, 50, None])
def test_ensure_spacing_p_norm(p, size):
    """Output points must all be farther apart than the requested spacing."""
    points = np.random.randn(100, 2)
    # Use the median pairwise distance between the points as spacing.
    spacing = np.median(pdist(points, metric=minkowski, p=p))
    filtered = ensure_spacing(points, spacing=spacing, p_norm=p, min_split_size=size)
    assert pdist(filtered, metric=minkowski, p=p).min() > spacing

View File

@@ -0,0 +1,14 @@
import numpy as np
import pytest
from ..dtype import numeric_dtype_min_max, numeric_types
class Test_numeric_dtype_min_max:
    """Sanity checks for `numeric_dtype_min_max` over all supported dtypes."""

    @pytest.mark.parametrize("dtype", numeric_types)
    def test_all_numeric_types(self, dtype):
        lo, hi = numeric_dtype_min_max(dtype)
        # Both bounds are scalars and ordered.
        assert np.isscalar(lo) and np.isscalar(hi)
        assert lo < hi

View File

@@ -0,0 +1,20 @@
from ..fast_exp import fast_exp
import numpy as np
def test_fast_exp():
    """fast_exp approximates np.exp on [-5, 0] within 3e-3 mean abs. error."""
    xs = np.linspace(-5, 0, 5000, endpoint=True)
    exact = np.exp(xs)
    # Approximation at double precision
    approx_f64 = np.array([fast_exp['float64_t'](x) for x in xs])
    # Approximation at single precision
    xs32 = xs.astype('float32')
    approx_f32 = np.array(
        [fast_exp['float32_t'](x) for x in xs32], dtype='float32'
    )
    for approx in (approx_f64, approx_f32):
        assert np.abs(exact - approx).mean() < 3e-3

View File

@@ -0,0 +1,81 @@
import pytest
from skimage._shared._geometry import polygon_clip, polygon_area
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
pytest.importorskip("matplotlib")
# Vertex coordinates of a hand-shaped test polygon (columns are the two
# planar coordinates; column 1 is used as rows, column 0 as columns in the
# clipping tests below).
hand = np.array(
    [
        [1.64516129, 1.16145833],
        [1.64516129, 1.59375],
        [1.35080645, 1.921875],
        [1.375, 2.18229167],
        [1.68548387, 1.9375],
        [1.60887097, 2.55208333],
        [1.68548387, 2.69791667],
        [1.76209677, 2.56770833],
        [1.83064516, 1.97395833],
        [1.89516129, 2.75],
        [1.9516129, 2.84895833],
        [2.01209677, 2.76041667],
        [1.99193548, 1.99479167],
        [2.11290323, 2.63020833],
        [2.2016129, 2.734375],
        [2.25403226, 2.60416667],
        [2.14919355, 1.953125],
        [2.30645161, 2.36979167],
        [2.39112903, 2.36979167],
        [2.41532258, 2.1875],
        [2.1733871, 1.703125],
        [2.07782258, 1.16666667],
    ]
)
def test_polygon_area():
    """Areas of simple known polygons."""
    # Unit square.
    assert_almost_equal(polygon_area([0, 1, 1, 0], [0, 0, 1, 1]), 1)
    # Right triangle with unit legs.
    assert_almost_equal(polygon_area([0, 1, 1], [0, 0, 1]), 0.5)
    # Self-intersecting "bow tie" shape.
    assert_almost_equal(
        polygon_area([0, 1, 0.5, 1, 0, 0.5], [0, 0, 0.5, 1, 1, 0.5]), 0.5
    )
def test_poly_clip():
    """Clipping these polygons to the unit box leaves half their area."""
    # Diamond poking out below the box.
    rows, cols = polygon_clip([0, -1, 0, 1], [0, 1, 2, 1], 0, 0, 1, 1)
    assert_equal(polygon_area(rows, cols), 0.5)
    # Rectangle straddling the box's far edge.
    rows, cols = polygon_clip([0.5, 0.5, 1.5, 1.5], [-1, 1.5, 1.5, -1], 0, 0, 1, 1)
    assert_equal(polygon_area(rows, cols), 0.5)
def test_hand_clip():
    """Clip the hand polygon with successively shorter bounding boxes."""
    rows, cols = hand[:, 1], hand[:, 0]

    clip_r, clip_c = polygon_clip(rows, cols, 1.0, 1.5, 2.1, 2.5)
    assert_equal(clip_r.size, 19)
    # The clipped polygon is closed: first and last vertex coincide.
    assert_equal(clip_r[0], clip_r[-1])
    assert_equal(clip_c[0], clip_c[-1])

    clip_r, clip_c = polygon_clip(rows, cols, 1.0, 1.5, 1.7, 2.5)
    assert_equal(clip_r.size, 6)

    clip_r, clip_c = polygon_clip(rows, cols, 1.0, 1.5, 1.5, 2.5)
    assert_equal(clip_r.size, 5)

View File

@@ -0,0 +1,28 @@
from skimage._shared.interpolation import coord_map_py
from skimage._shared.testing import assert_array_equal
def test_coord_map():
    """Check every boundary mode of `coord_map_py` over a window of indices."""
    indices = range(-6, 6)
    # mode -> expected mapping of each index onto an axis of length 4
    expected_by_mode = {
        'S': [2, 3, 3, 2, 1, 0, 0, 1, 2, 3, 3, 2],  # symmetric
        'W': [2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1],  # wrap
        'E': [0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 3, 3],  # edge
        'R': [0, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1],  # reflect
    }
    for mode, expected in expected_by_mode.items():
        mapped = [coord_map_py(4, n, mode) for n in indices]
        assert_array_equal(mapped, expected)
    # Reflecting onto a single-element axis maps everything to 0.
    assert_array_equal([coord_map_py(1, n, 'R') for n in indices], [0] * 12)
    # Unknown modes leave the coordinate untouched.
    assert_array_equal([coord_map_py(4, n, 'undefined') for n in indices], list(indices))

View File

@@ -0,0 +1,41 @@
import numpy as np
from skimage._shared.utils import safe_as_int
from skimage._shared import testing
def test_int_cast_not_possible():
    """Values too far from an integer must raise ValueError."""
    bad_inputs = [
        7.1,
        [7.1, 0.9],
        np.r_[7.1, 0.9],
        (7.1, 0.9),
        ((3, 4, 1), (2, 7.6, 289)),
    ]
    # Default tolerance.
    for value in bad_inputs:
        with testing.raises(ValueError):
            safe_as_int(value)
    # Explicit, still-too-small tolerances.
    for value, atol in zip(bad_inputs, (0.09, 0.09, 0.09, 0.09, 0.25)):
        with testing.raises(ValueError):
            safe_as_int(value, atol)
def test_int_cast_possible():
    """Values within tolerance of an integer cast to that integer."""
    for value, expected in [(7.1, 7), (-7.1, -7), (41.9, 42)]:
        testing.assert_equal(safe_as_int(value, atol=0.11), expected)
    # Exact integral floats need no tolerance.
    testing.assert_array_equal(
        safe_as_int([2, 42, 5789234.0, 87, 4]), np.r_[2, 42, 5789234, 87, 4]
    )
    # Near-integral values within the default tolerance, 2-D input.
    testing.assert_array_equal(
        safe_as_int(
            np.r_[[[3, 4, 1.000000001], [7, 2, -8.999999999], [6, 9, -4234918347.0]]]
        ),
        np.r_[[[3, 4, 1], [7, 2, -9], [6, 9, -4234918347]]],
    )

View File

@@ -0,0 +1,154 @@
""" Testing decorators module
"""
import inspect
import re
import warnings
import pytest
from numpy.testing import assert_equal
from skimage._shared.testing import (
doctest_skip_parser,
run_in_parallel,
assert_stacklevel,
)
from skimage._shared import testing
from skimage._shared._dependency_checks import is_wasm
from skimage._shared._warnings import expected_warnings
from warnings import warn
def test_skipper():
    # `doctest_skip_parser` rewrites "# skip if <FLAG>" markers in a
    # docstring into "# doctest: +SKIP" directives (or removes them),
    # depending on the truth value of module-level flags of the same name.
    def f():
        pass

    class c:
        def __init__(self):
            self.me = "I think, therefore..."

    docstring = """ Header
>>> something # skip if not HAVE_AMODULE
>>> something + else
>>> a = 1 # skip if not HAVE_BMODULE
>>> something2 # skip if HAVE_AMODULE
"""
    f.__doc__ = docstring
    c.__doc__ = docstring

    # Flags are looked up in this module's global namespace.
    global HAVE_AMODULE, HAVE_BMODULE
    HAVE_AMODULE = False
    HAVE_BMODULE = True

    f2 = doctest_skip_parser(f)
    c2 = doctest_skip_parser(c)
    # The parser modifies the object in place and returns it.
    assert f is f2
    assert c is c2

    expected = """ Header
>>> something # doctest: +SKIP
>>> something + else
>>> a = 1
>>> something2
"""
    assert_equal(f2.__doc__, expected)
    assert_equal(c2.__doc__, expected)

    # Flip both flags: the opposite set of examples must now be skipped.
    HAVE_AMODULE = True
    HAVE_BMODULE = False
    f.__doc__ = docstring
    c.__doc__ = docstring
    f2 = doctest_skip_parser(f)
    c2 = doctest_skip_parser(c)
    assert f is f2
    expected = """ Header
>>> something
>>> something + else
>>> a = 1 # doctest: +SKIP
>>> something2 # doctest: +SKIP
"""
    assert_equal(f2.__doc__, expected)
    assert_equal(c2.__doc__, expected)

    # Referencing an undefined flag raises NameError.
    del HAVE_AMODULE
    f.__doc__ = docstring
    c.__doc__ = docstring
    with testing.raises(NameError):
        doctest_skip_parser(f)
    with testing.raises(NameError):
        doctest_skip_parser(c)
@pytest.mark.skipif(is_wasm, reason="Cannot start threads in WASM")
def test_run_in_parallel():
    """`run_in_parallel` invokes the wrapped function once per thread."""
    calls = []

    @run_in_parallel()
    def with_default_threads():
        calls.append(None)

    with_default_threads()
    assert len(calls) == 2  # default is two threads

    @run_in_parallel(num_threads=1)
    def with_one_thread():
        calls.append(None)

    with_one_thread()
    assert len(calls) == 3

    @run_in_parallel(num_threads=3)
    def with_three_threads():
        calls.append(None)

    with_three_threads()
    assert len(calls) == 6
def test_parallel_warning():
    """Warnings raised inside parallel workers can be matched."""

    @run_in_parallel()
    def warns_unmatched():
        warn("Test warning for test parallel", stacklevel=2)

    # Warning surfaces to the surrounding context manager.
    with expected_warnings(['Test warning for test parallel']):
        warns_unmatched()

    @run_in_parallel(warnings_matching=['Test warning for test parallel'])
    def warns_matched_by_decorator():
        warn("Test warning for test parallel", stacklevel=2)

    warns_matched_by_decorator()
def test_expected_warnings_noop():
    """`expected_warnings(None)` must behave like a transparent no-op."""
    # The outer manager still sees the warning raised below the inner,
    # no-op manager.
    with expected_warnings(['Expected warnings test']):
        with expected_warnings(None):
            warn('Expected warnings test')
class Test_assert_stacklevel:
    # Helper so the warning originates one frame below the test; with
    # stacklevel=2 the reported location is the caller's line.
    def raise_warning(self, *args, **kwargs):
        warnings.warn(*args, **kwargs)

    def test_correct_stacklevel(self):
        # Should pass if stacklevel is set correctly
        with pytest.warns(UserWarning, match="passes") as record:
            self.raise_warning("passes", UserWarning, stacklevel=2)
        assert_stacklevel(record)

    @pytest.mark.parametrize("level", [1, 3])
    def test_wrong_stacklevel(self, level):
        # AssertionError should be raised for wrong stacklevel
        with pytest.warns(UserWarning, match="wrong") as record:
            self.raise_warning("wrong", UserWarning, stacklevel=level)
        # Check that message contains expected line on right side
        # (the `- 2` points at the `self.raise_warning(...)` line above;
        # keep the relative line spacing intact when editing).
        line_number = inspect.currentframe().f_lineno - 2
        regex = ".*" + re.escape(f"!= {__file__}:{line_number}")
        with pytest.raises(AssertionError, match=regex):
            assert_stacklevel(record, offset=-5)

View File

@@ -0,0 +1,516 @@
import sys
import warnings
import numpy as np
import pytest
from skimage._shared import testing
from skimage._shared.utils import (
_supported_float_type,
_validate_interpolation_order,
change_default_value,
channel_as_last_axis,
check_nD,
deprecate_func,
deprecate_parameter,
DEPRECATED,
)
# Complex dtypes to exercise; complex256 exists only on some platforms.
complex_dtypes = [np.complex64, np.complex128]
if hasattr(np, 'complex256'):
    complex_dtypes += [np.complex256]

# numpydoc is an optional dependency; the docstring-rewriting tests below
# are skipped when it is unavailable.
have_numpydoc = False
try:
    import numpydoc  # noqa: F401

    have_numpydoc = True
except ImportError:
    pass
def test_change_default_value():
    # Decorated with the default, auto-generated warning message.
    @change_default_value('arg1', new_value=-1, changed_version='0.12')
    def foo(arg0, arg1=0, arg2=1):
        """Expected docstring"""
        return arg0, arg1, arg2

    # Decorated with a custom warning message.
    @change_default_value(
        'arg1',
        new_value=-1,
        changed_version='0.12',
        warning_msg="Custom warning message",
    )
    def bar(arg0, arg1=0, arg2=1):
        """Expected docstring"""
        return arg0, arg1, arg2

    # Assert warning messages
    with pytest.warns(FutureWarning) as record:
        # Omitting `arg1` triggers the warning; the old default still applies.
        assert foo(0) == (0, 0, 1)
        assert bar(0) == (0, 0, 1)

    expected_msg = (
        "The new recommended value for arg1 is -1. Until "
        "version 0.12, the default arg1 value is 0. From "
        "version 0.12, the arg1 default value will be -1. "
        "To avoid this warning, please explicitly set arg1 value."
    )
    assert str(record[0].message) == expected_msg
    assert str(record[1].message) == "Custom warning message"

    # Assert that nothing happens if arg1 is set
    with warnings.catch_warnings(record=True) as recorded:
        # No kwargs
        assert foo(0, 2) == (0, 2, 1)
        assert foo(0, arg1=0) == (0, 0, 1)

    # Function name and doc is preserved
    assert foo.__name__ == 'foo'
    if sys.flags.optimize < 2:
        # if PYTHONOPTIMIZE is set to 2, docstrings are stripped
        assert foo.__doc__ == 'Expected docstring'

    # Assert no warnings were raised
    assert len(recorded) == 0
def test_check_nD():
    """`check_nD` rejects arrays with an empty dimension."""
    z = np.random.random(200**2).reshape((200, 200))
    empty_slice = z[10:30, 30:10]  # second axis has length 0
    with testing.raises(ValueError):
        check_nD(empty_slice, 2)
@pytest.mark.parametrize(
    'dtype', [bool, int, np.uint8, np.uint16, float, np.float32, np.float64]
)
@pytest.mark.parametrize('order', [None, -1, 0, 1, 2, 3, 4, 5, 6])
def test_validate_interpolation_order(dtype, order):
    """Check default, invalid, and valid interpolation orders per dtype."""
    if order is None:
        # Default order: 0 for bool images, 1 otherwise.
        # BUG FIX: this was `assert _validate(...) == 0 if dtype == bool else 1`,
        # which parses as `(x == 0) if dtype == bool else 1` and therefore
        # asserted the constant 1 (always truthy) for every non-bool dtype,
        # making the check vacuous.
        expected = 0 if dtype == bool else 1
        assert _validate_interpolation_order(dtype, None) == expected
    elif order < 0 or order > 5:
        # Order not in valid range
        with testing.raises(ValueError):
            _validate_interpolation_order(dtype, order)
    elif dtype == bool and order != 0:
        # Non-zero orders are invalid for bool arrays
        with pytest.raises(ValueError):
            _validate_interpolation_order(bool, order)
    else:
        # Valid use case
        assert _validate_interpolation_order(dtype, order) == order
@pytest.mark.parametrize(
    'dtype',
    [
        bool,
        np.float16,
        np.float32,
        np.float64,
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
    ],
)
def test_supported_float_dtype_real(dtype):
    """16/32-bit floats map to float32, every other real dtype to float64."""
    expected = np.float32 if dtype in (np.float16, np.float32) else np.float64
    assert _supported_float_type(dtype) == expected
@pytest.mark.parametrize('dtype', complex_dtypes)
@pytest.mark.parametrize('allow_complex', [False, True])
def test_supported_float_dtype_complex(dtype, allow_complex):
    """Complex dtypes are promoted only when explicitly allowed."""
    if not allow_complex:
        with testing.raises(ValueError):
            _supported_float_type(dtype, allow_complex=allow_complex)
    else:
        expected = np.complex64 if dtype == np.complex64 else np.complex128
        promoted = _supported_float_type(dtype, allow_complex=allow_complex)
        assert promoted == expected
@pytest.mark.parametrize('dtype', ['f', 'float32', np.float32, np.dtype(np.float32)])
def test_supported_float_dtype_input_kinds(dtype):
    """Strings, scalar types and dtype objects are all accepted as input."""
    assert _supported_float_type(dtype) == np.float32
@pytest.mark.parametrize(
    'dtypes, expected',
    [
        ((np.float16, np.float64), np.float64),
        ((np.float32, np.uint16, np.int8), np.float64),
        ((np.float32, np.float16), np.float32),
    ],
)
def test_supported_float_dtype_sequence(dtypes, expected):
    """A sequence of dtypes promotes to their common supported float type."""
    assert _supported_float_type(dtypes) == expected
@channel_as_last_axis(multichannel_output=False)
def _decorated_channel_axis_size(x, *, channel_axis=None):
    """Return the size of `x` along its channel axis (or None).

    The `channel_as_last_axis` decorator is expected to have moved the
    channel axis to position -1 before this body runs; the assert below
    verifies exactly that.
    """
    if channel_axis is None:
        return None
    assert channel_axis == -1
    return x.shape[-1]
@testing.parametrize('channel_axis', [None, 0, 1, 2, -1, -2, -3])
def test_decorated_channel_axis_shape(channel_axis):
    """`channel_as_last_axis` must report the size of the requested axis."""
    # Unique size per axis so a wrong axis is detected.
    data = np.zeros((2, 3, 4))
    reported = _decorated_channel_axis_size(data, channel_axis=channel_axis)
    expected = None if channel_axis is None else data.shape[channel_axis]
    assert reported == expected
# Deliberately nonsensical version strings ("x"/"y"); the exact warning text
# is asserted in `test_deprecate_func`.
@deprecate_func(
    deprecated_version="x", removed_version="y", hint="You are on your own."
)
def _deprecated_func():
    """Dummy function used in `test_deprecate_func`.

    The decorated function must be outside the test function, otherwise it
    seems that the warning does not point at the calling location.
    """
def test_deprecate_func():
    """Calling a deprecated function warns once with the full message."""
    with pytest.warns(FutureWarning) as record:
        _deprecated_func()
    testing.assert_stacklevel(record)
    assert len(record) == 1
    expected = (
        "`_deprecated_func` is deprecated since version x and will be removed in "
        "version y. You are on your own."
    )
    assert record[0].message.args[0] == expected
# NOTE: the docstring below is compared verbatim (after numpydoc rewriting)
# in Test_deprecate_parameter.test_docstring_removed_param -- do not edit it.
@deprecate_parameter("old1", start_version="0.10", stop_version="0.12")
@deprecate_parameter("old0", start_version="0.10", stop_version="0.12")
def _func_deprecated_params(arg0, old0=DEPRECATED, old1=DEPRECATED, arg1=None):
    """Expected docstring.

    Parameters
    ----------
    arg0 : int
        First unchanged parameter.
    arg1 : int, optional
        Second unchanged parameter.
    """
    return arg0, old0, old1, arg1
# NOTE: the cross mapping (old1 -> new0, old0 -> new1) is deliberate and is
# asserted by Test_deprecate_parameter; the docstring below is compared
# verbatim (after numpydoc rewriting) -- do not edit either.
@deprecate_parameter("old1", new_name="new0", start_version="0.10", stop_version="0.12")
@deprecate_parameter("old0", new_name="new1", start_version="0.10", stop_version="0.12")
def _func_replace_params(
    arg0, old0=DEPRECATED, old1=DEPRECATED, new0=None, new1=None, arg1=None
):
    """Expected docstring.

    Parameters
    ----------
    arg0 : int
        First unchanged parameter.
    new0 : int, optional
        First new parameter.

        .. versionadded:: 0.10
    new1 : int, optional
        Second new parameter.

        .. versionadded:: 0.10
    arg1 : int, optional
        Second unchanged parameter.
    """
    return arg0, old0, old1, new0, new1, arg1
class Test_deprecate_parameter:
    # End-to-end tests for the `deprecate_parameter` decorator, based on the
    # module-level `_func_deprecated_params` / `_func_replace_params`.

    @pytest.mark.skipif(not have_numpydoc, reason="requires numpydoc")
    def test_docstring_removed_param(self):
        # function name and doc are preserved
        assert _func_deprecated_params.__name__ == "_func_deprecated_params"
        if sys.flags.optimize < 2:
            # if PYTHONOPTIMIZE is set to 2, docstrings are stripped
            assert (
                _func_deprecated_params.__doc__
                == """Expected docstring.

    Parameters
    ----------
    arg0 : int
        First unchanged parameter.
    arg1 : int, optional
        Second unchanged parameter.

    Other Parameters
    ----------------
    old0 : DEPRECATED
        `old0` is deprecated.

        .. deprecated:: 0.10
    old1 : DEPRECATED
        `old1` is deprecated.

        .. deprecated:: 0.10
    """
            )

    @pytest.mark.skipif(not have_numpydoc, reason="requires numpydoc")
    def test_docstring_replaced_param(self):
        assert _func_replace_params.__name__ == "_func_replace_params"
        if sys.flags.optimize < 2:
            # if PYTHONOPTIMIZE is set to 2, docstrings are stripped
            assert (
                _func_replace_params.__doc__
                == """Expected docstring.

    Parameters
    ----------
    arg0 : int
        First unchanged parameter.
    new0 : int, optional
        First new parameter.

        .. versionadded:: 0.10
    new1 : int, optional
        Second new parameter.

        .. versionadded:: 0.10
    arg1 : int, optional
        Second unchanged parameter.

    Other Parameters
    ----------------
    old0 : DEPRECATED
        Deprecated in favor of `new1`.

        .. deprecated:: 0.10
    old1 : DEPRECATED
        Deprecated in favor of `new0`.

        .. deprecated:: 0.10
    """
            )

    def test_warning_removed_param(self):
        match = (
            r".*`old[01]` is deprecated since version 0\.10 and will be removed "
            r"in 0\.12.* see the documentation of .*_func_deprecated_params`."
        )
        # Deprecated values passed positionally or by keyword warn and are
        # discarded (the function still receives DEPRECATED).
        with pytest.warns(FutureWarning, match=match):
            assert _func_deprecated_params(1, 2) == (1, DEPRECATED, DEPRECATED, None)
        with pytest.warns(FutureWarning, match=match):
            assert _func_deprecated_params(1, 2, 3) == (1, DEPRECATED, DEPRECATED, None)
        with pytest.warns(FutureWarning, match=match):
            assert _func_deprecated_params(1, old0=2) == (
                1,
                DEPRECATED,
                DEPRECATED,
                None,
            )
        with pytest.warns(FutureWarning, match=match):
            assert _func_deprecated_params(1, old1=2) == (
                1,
                DEPRECATED,
                DEPRECATED,
                None,
            )
        # No warning when the deprecated parameters are not used.
        with warnings.catch_warnings(record=True) as record:
            assert _func_deprecated_params(1, arg1=3) == (1, DEPRECATED, DEPRECATED, 3)
        assert len(record) == 0

    def test_warning_replaced_param(self):
        match = (
            r".*`old[0,1]` is deprecated since version 0\.10 and will be removed "
            r"in 0\.12.* see the documentation of .*_func_replace_params`."
        )
        # Values given via deprecated names are forwarded to the new names
        # (old0 -> new1, old1 -> new0, see decorator definitions).
        with pytest.warns(FutureWarning, match=match):
            assert _func_replace_params(1, 2) == (
                1,
                DEPRECATED,
                DEPRECATED,
                None,
                2,
                None,
            )
        with pytest.warns(FutureWarning, match=match) as records:
            assert _func_replace_params(1, 2, 3) == (
                1,
                DEPRECATED,
                DEPRECATED,
                3,
                2,
                None,
            )
        # One warning per deprecated parameter; innermost decorator warns last.
        assert len(records) == 2
        assert "`old1` is deprecated" in records[0].message.args[0]
        assert "`old0` is deprecated" in records[1].message.args[0]
        with pytest.warns(FutureWarning, match=match):
            assert _func_replace_params(1, old0=2) == (
                1,
                DEPRECATED,
                DEPRECATED,
                None,
                2,
                None,
            )
        with pytest.warns(FutureWarning, match=match):
            assert _func_replace_params(1, old1=3) == (
                1,
                DEPRECATED,
                DEPRECATED,
                3,
                None,
                None,
            )
        # Otherwise, no warnings are emitted!
        with warnings.catch_warnings(record=True) as record:
            assert _func_replace_params(1, new0=2, new1=3) == (
                1,
                DEPRECATED,
                DEPRECATED,
                2,
                3,
                None,
            )
        assert len(record) == 0

    def test_missing_DEPRECATED(self):
        # A deprecated parameter whose default is not the DEPRECATED sentinel
        # must be rejected at decoration time.
        decorate = deprecate_parameter(
            "old", start_version="0.10", stop_version="0.12", stacklevel=2
        )

        def foo(arg0, old=None):
            return arg0, old

        with pytest.raises(RuntimeError, match="Expected .* <DEPRECATED>"):
            decorate(foo)

        def bar(arg0, old=DEPRECATED):
            return arg0

        assert decorate(bar)(1) == 1

    def test_new_keyword_only(self):
        # The new parameter may be keyword-only.
        @deprecate_parameter(
            "old",
            new_name="new",
            start_version="0.19",
            stop_version="0.21",
        )
        def foo(arg0, old=DEPRECATED, *, new=1, arg3=None):
            """Expected docstring"""
            return arg0, new, arg3

        # Assert that nothing happens when the function is called with the
        # new API
        with warnings.catch_warnings(record=True) as recorded:
            # No kwargs
            assert foo(0) == (0, 1, None)
            # Kwargs without deprecated argument
            assert foo(0, new=1, arg3=2) == (0, 1, 2)
            assert foo(0, new=2) == (0, 2, None)
            assert foo(0, arg3=2) == (0, 1, 2)
        assert len(recorded) == 0

    def test_conflicting_old_and_new(self):
        # Passing both the old and the replacement name is an error.
        match = r".*`old[0,1]` is deprecated"
        with pytest.warns(FutureWarning, match=match):
            with pytest.raises(ValueError, match=".* avoid conflicting values"):
                _func_replace_params(1, old0=2, new1=2)
        with pytest.warns(FutureWarning, match=match):
            with pytest.raises(ValueError, match=".* avoid conflicting values"):
                _func_replace_params(1, old1=2, new0=2)
        with pytest.warns(FutureWarning, match=match):
            with pytest.raises(ValueError, match=".* avoid conflicting values"):
                _func_replace_params(1, old0=1, old1=1, new0=1, new1=1)

    def test_wrong_call_signature(self):
        """Check that normal errors for faulty calls are unchanged."""
        with pytest.raises(
            TypeError, match=r".* required positional argument\: 'arg0'"
        ):
            _func_replace_params()
        with pytest.warns(FutureWarning, match=r".*`old[0,1]` is deprecated"):
            with pytest.raises(
                TypeError, match=".* multiple values for argument 'old0'"
            ):
                _func_deprecated_params(1, 2, old0=2)

    def test_wrong_param_name(self):
        # Decorating with names absent from the signature fails early.
        with pytest.raises(ValueError, match="'old' is not in list"):

            @deprecate_parameter("old", start_version="0.10", stop_version="0.12")
            def foo(arg0):
                pass

        with pytest.raises(ValueError, match="'new' is not in list"):

            @deprecate_parameter(
                "old", new_name="new", start_version="0.10", stop_version="0.12"
            )
            def bar(arg0, old, arg1):
                pass

    def test_warning_location(self):
        # The warning must point at the caller's frame.
        with pytest.warns(FutureWarning) as records:
            _func_deprecated_params(1, old0=2, old1=2)
        testing.assert_stacklevel(records)
        assert len(records) == 2

    def test_stacklevel(self):
        # For closures the stacklevel cannot be auto-detected ...
        @deprecate_parameter(
            "old",
            start_version="0.19",
            stop_version="0.21",
        )
        def foo(arg0, old=DEPRECATED):
            pass

        with pytest.raises(RuntimeError, match="Set stacklevel manually"):
            foo(0, 1)

        # ... but an explicit stacklevel works.
        @deprecate_parameter(
            "old",
            start_version="0.19",
            stop_version="0.21",
            stacklevel=2,
        )
        def bar(arg0, old=DEPRECATED):
            pass

        with pytest.warns(FutureWarning, match="`old` is deprecated") as records:
            bar(0, 1)
        testing.assert_stacklevel(records)

View File

@@ -0,0 +1,42 @@
"""Tests for the version requirement functions.
"""
import numpy as np
from numpy.testing import assert_equal
from skimage._shared import version_requirements as version_req
from skimage._shared import testing
def test_get_module_version():
    """Installed modules yield a version; unknown modules raise ImportError."""
    for module_name in ('numpy', 'scipy'):
        assert version_req.get_module_version(module_name)
    with testing.raises(ImportError):
        version_req.get_module_version('fakenumpy')
def test_is_installed():
    """`is_installed` checks both presence and the version constraint."""
    # An impossible version constraint fails even for an installed package.
    assert not version_req.is_installed('numpy', '<1.0')
    assert version_req.is_installed('python', '>=2.7')
def test_require():
    """@require runs the function when satisfied, raises ImportError if not."""

    # Requirements that any supported environment satisfies.
    @version_req.require('python', '>2.7')
    @version_req.require('numpy', '>1.5')
    def satisfied():
        return 1

    assert_equal(satisfied(), 1)

    # A requirement no environment satisfies (scipy < 0.1).
    @version_req.require('scipy', '<0.1')
    def unsatisfiable():
        return 0

    with testing.raises(ImportError):
        unsatisfiable()
def test_get_module():
    """`get_module` returns the imported module object itself."""
    assert version_req.get_module("numpy") is np

View File

@@ -0,0 +1,37 @@
import os
from skimage._shared._warnings import expected_warnings
import pytest
@pytest.fixture(scope='function')
def setup():
    """Clear SKIMAGE_TEST_STRICT_WARNINGS for the test, then restore it."""
    saved_strictness = os.environ.pop('SKIMAGE_TEST_STRICT_WARNINGS', None)
    yield
    # Put back the user's desired strictness, if any was set.
    if saved_strictness is not None:
        os.environ['SKIMAGE_TEST_STRICT_WARNINGS'] = saved_strictness
def test_strict_warnigns_default(setup):
    """With no env override, missing expected warnings must raise."""
    with pytest.raises(ValueError):
        with expected_warnings(['some warnings']):
            pass  # never emits the expected warning
@pytest.mark.parametrize('strictness', ['1', 'true', 'True', 'TRUE'])
def test_strict_warning_true(setup, strictness):
    """Truthy env values keep the warning check strict (missing -> error)."""
    os.environ['SKIMAGE_TEST_STRICT_WARNINGS'] = strictness
    with pytest.raises(ValueError):
        with expected_warnings(['some warnings']):
            pass
@pytest.mark.parametrize('strictness', ['0', 'false', 'False', 'FALSE'])
def test_strict_warning_false(setup, strictness):
    """Falsy env values disable strictness: missing warnings are tolerated."""
    os.environ['SKIMAGE_TEST_STRICT_WARNINGS'] = strictness
    with expected_warnings(['some warnings']):
        pass

View File

@@ -0,0 +1,884 @@
import functools
import inspect
import sys
import warnings
import numpy as np
from ._warnings import all_warnings, warn
# Names re-exported as the public API of this utility module.
__all__ = [
    'deprecate_func',
    'get_bound_method_class',
    'all_warnings',
    'safe_as_int',
    'check_shape_equality',
    'check_nD',
    'warn',
    'reshape_nd',
    'identity',
    'slice_at_axis',
    "deprecate_parameter",
    "DEPRECATED",
]
def _count_wrappers(func):
"""Count the number of wrappers around `func`."""
unwrapped = func
count = 0
while hasattr(unwrapped, "__wrapped__"):
unwrapped = unwrapped.__wrapped__
count += 1
return count
def _warning_stacklevel(func):
    """Find stacklevel for a warning raised from a wrapper around `func`.

    Compares the number of wrappers around the given `func` with the number
    of wrappers around the module-level object of the same qualified name,
    to derive how many frames a warning must skip to point at the caller.

    Parameters
    ----------
    func : Callable
        The (possibly wrapped) function.

    Returns
    -------
    stacklevel : int
        The stacklevel. Minimum of 2.

    Raises
    ------
    RuntimeError
        If the object matching ``func.__qualname__`` cannot be resolved from
        its module (e.g. `func` is a closure); set the stacklevel manually
        in that case.
    """
    # Count number of wrappers around `func`
    wrapped_count = _count_wrappers(func)
    # Count number of total wrappers around the global version of `func`.
    # BUG FIX: walk the attribute chain starting from the module, so dotted
    # qualified names such as "Class.method" resolve. The previous code
    # looked every name component up on the module itself, which raised
    # spuriously for anything but a top-level function.
    module = sys.modules.get(func.__module__)
    global_func = module
    try:
        for name in func.__qualname__.split("."):
            global_func = getattr(global_func, name)
    except AttributeError as e:
        raise RuntimeError(
            f"Could not access `{func.__qualname__}` in {module!r}, "
            f" may be a closure. Set stacklevel manually. ",
        ) from e
    else:
        global_wrapped_count = _count_wrappers(global_func)
        stacklevel = global_wrapped_count - wrapped_count + 1
    return max(stacklevel, 2)
def _get_stack_length(func):
    """Return function call stack length.

    Prefers the module-level object of the same name (the fully wrapped
    version), falling back to `func` itself when no global exists.
    """
    global_version = func.__globals__.get(func.__name__, func)
    return _count_wrappers(global_version)
class _DecoratorBaseClass:
    """Used to manage decorators' warnings stacklevel.

    The `_stack_length` class variable is used to store the number of
    times a function is wrapped by a decorator.

    Let `stack_length` be the total number of times a decorated
    function is wrapped, and `stack_rank` be the rank of the decorator
    in the decorators stack. The stacklevel of a warning is then
    `stacklevel = 1 + stack_length - stack_rank`.
    """

    # Shared cache mapping function __name__ -> total wrapper count.
    _stack_length = {}

    def get_stack_length(self, func):
        # Fall back to counting wrappers dynamically when `func` was never
        # registered in the cache.
        length = self._stack_length.get(func.__name__, _get_stack_length(func))
        return length
class change_default_value(_DecoratorBaseClass):
    """Decorator for changing the default value of an argument.

    Parameters
    ----------
    arg_name: str
        The name of the argument to be updated.
    new_value: any
        The argument new value.
    changed_version : str
        The package version in which the change will be introduced.
    warning_msg: str
        Optional warning message. If None, a generic warning message
        is used.
    """

    def __init__(self, arg_name, *, new_value, changed_version, warning_msg=None):
        self.arg_name = arg_name
        self.new_value = new_value
        self.warning_msg = warning_msg
        self.changed_version = changed_version

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        # Positional index of the argument, to detect positional use below.
        arg_idx = list(parameters.keys()).index(self.arg_name)
        old_value = parameters[self.arg_name].default

        # Rank of this decorator within the wrapper stack around `func`,
        # used for the warning's stacklevel at call time.
        stack_rank = _count_wrappers(func)

        if self.warning_msg is None:
            # NOTE(review): this mutates `self.warning_msg`, so reusing a
            # single decorator instance on several functions would freeze the
            # first generated message -- confirm instances are never shared.
            self.warning_msg = (
                f'The new recommended value for {self.arg_name} is '
                f'{self.new_value}. Until version {self.changed_version}, '
                f'the default {self.arg_name} value is {old_value}. '
                f'From version {self.changed_version}, the {self.arg_name} '
                f'default value will be {self.new_value}. To avoid '
                f'this warning, please explicitly set {self.arg_name} value.'
            )

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            stacklevel = 1 + self.get_stack_length(func) - stack_rank
            # Only warn when the caller relies on the (changing) default.
            if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys():
                # warn that arg_name default value changed:
                warnings.warn(self.warning_msg, FutureWarning, stacklevel=stacklevel)
            return func(*args, **kwargs)

        return fixed_func
class PatchClassRepr(type):
    """Metaclass rendering the class object itself as ``<ClassName>``.

    Keeps sentinel classes tidy when they appear in rendered signatures.
    """

    def __repr__(cls):
        return "<" + cls.__name__ + ">"
# Never instantiated; the class object itself is the sentinel and is compared
# with `is` / `is not`.
class DEPRECATED(metaclass=PatchClassRepr):
    """Signal value to help with deprecating parameters that use None.

    This is a proxy object, used to signal that a parameter has not been set.
    This is useful if ``None`` is already used for a different purpose or just
    to highlight a deprecated parameter in the signature.
    """
class deprecate_parameter:
    """Deprecate a parameter of a function.

    Parameters
    ----------
    deprecated_name : str
        The name of the deprecated parameter.
    start_version : str
        The package version in which the warning was introduced.
    stop_version : str
        The package version in which the warning will be replaced by
        an error / the deprecation is completed.
    template : str, optional
        If given, this message template is used instead of the default one.
    new_name : str, optional
        If given, the default message will recommend the new parameter name and an
        error will be raised if the user uses both old and new names for the
        same parameter.
    modify_docstring : bool, optional
        If the wrapped function has a docstring, add the deprecated parameters
        to the "Other Parameters" section.
    stacklevel : int, optional
        This decorator attempts to detect the appropriate stacklevel for the
        deprecation warning automatically. If this fails, e.g., due to
        decorating a closure, you can set the stacklevel manually. The
        outermost decorator should have stacklevel 2, the next inner one
        stacklevel 3, etc.

    Notes
    -----
    Assign `DEPRECATED` as the new default value for the deprecated parameter.
    This marks the status of the parameter also in the signature and rendered
    HTML docs.

    This decorator can be stacked to deprecate more than one parameter.

    Examples
    --------
    >>> from skimage._shared.utils import deprecate_parameter, DEPRECATED
    >>> @deprecate_parameter(
    ...     "b", new_name="c", start_version="0.1", stop_version="0.3"
    ... )
    ... def foo(a, b=DEPRECATED, *, c=None):
    ...     return a, c

    Calling ``foo(1, b=2)`` will warn with::

        FutureWarning: Parameter `b` is deprecated since version 0.1 and will
        be removed in 0.3 (or later). To avoid this warning, please use the
        parameter `c` instead. For more details, see the documentation of
        `foo`.
    """

    DEPRECATED = DEPRECATED  # Make signal value accessible for convenience

    # Message template used when the parameter is dropped without replacement.
    remove_parameter_template = (
        "Parameter `{deprecated_name}` is deprecated since version "
        "{deprecated_version} and will be removed in {changed_version} (or "
        "later). To avoid this warning, please do not use the parameter "
        "`{deprecated_name}`. For more details, see the documentation of "
        "`{func_name}`."
    )
    # Message template used when the parameter is renamed to `new_name`.
    replace_parameter_template = (
        "Parameter `{deprecated_name}` is deprecated since version "
        "{deprecated_version} and will be removed in {changed_version} (or "
        "later). To avoid this warning, please use the parameter `{new_name}` "
        "instead. For more details, see the documentation of `{func_name}`."
    )

    def __init__(
        self,
        deprecated_name,
        *,
        start_version,
        stop_version,
        template=None,
        new_name=None,
        modify_docstring=True,
        stacklevel=None,
    ):
        # See the class docstring for parameter semantics.
        self.deprecated_name = deprecated_name
        self.new_name = new_name
        self.template = template
        self.start_version = start_version
        self.stop_version = stop_version
        self.modify_docstring = modify_docstring
        self.stacklevel = stacklevel

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        # Positional indices of the deprecated (and, if any, the new)
        # parameter; used to intercept values passed positionally.
        deprecated_idx = list(parameters.keys()).index(self.deprecated_name)
        if self.new_name:
            new_idx = list(parameters.keys()).index(self.new_name)
        else:
            new_idx = False
        # The deprecated parameter must default to the DEPRECATED sentinel so
        # the status is visible in the rendered signature.
        if parameters[self.deprecated_name].default is not DEPRECATED:
            raise RuntimeError(
                f"Expected `{self.deprecated_name}` to have the value {DEPRECATED!r} "
                f"to indicate its status in the rendered signature."
            )
        # Select the warning template: custom > replacement > removal.
        if self.template is not None:
            template = self.template
        elif self.new_name is not None:
            template = self.replace_parameter_template
        else:
            template = self.remove_parameter_template
        warning_message = template.format(
            deprecated_name=self.deprecated_name,
            deprecated_version=self.start_version,
            changed_version=self.stop_version,
            func_name=func.__qualname__,
            new_name=self.new_name,
        )

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            deprecated_value = DEPRECATED
            new_value = DEPRECATED
            # Extract value of deprecated parameter
            if len(args) > deprecated_idx:
                deprecated_value = args[deprecated_idx]
                # Replace the positional value with the sentinel so `func`
                # never sees the deprecated value directly.
                args = (
                    args[:deprecated_idx] + (DEPRECATED,) + args[deprecated_idx + 1 :]
                )
            if self.deprecated_name in kwargs.keys():
                deprecated_value = kwargs[self.deprecated_name]
                kwargs[self.deprecated_name] = DEPRECATED
            # Extract value of new parameter (if present)
            if new_idx is not False and len(args) > new_idx:
                new_value = args[new_idx]
            if self.new_name and self.new_name in kwargs.keys():
                new_value = kwargs[self.new_name]
            if deprecated_value is not DEPRECATED:
                # The deprecated parameter was actually used: warn, and
                # forward its value to the replacement (if one exists).
                stacklevel = (
                    self.stacklevel
                    if self.stacklevel is not None
                    else _warning_stacklevel(func)
                )
                warnings.warn(
                    warning_message, category=FutureWarning, stacklevel=stacklevel
                )
                if new_value is not DEPRECATED:
                    raise ValueError(
                        f"Both deprecated parameter `{self.deprecated_name}` "
                        f"and new parameter `{self.new_name}` are used. Use "
                        f"only the latter to avoid conflicting values."
                    )
                elif self.new_name is not None:
                    # Assign old value to new one
                    kwargs[self.new_name] = deprecated_value
            return func(*args, **kwargs)

        if self.modify_docstring and func.__doc__ is not None:
            # Document the deprecated parameter under "Other Parameters".
            newdoc = _docstring_add_deprecated(
                func, {self.deprecated_name: self.new_name}, self.start_version
            )
            fixed_func.__doc__ = newdoc
        return fixed_func
def _docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
    """Add deprecated kwarg(s) to the "Other Params" section of a docstring.

    Parameters
    ----------
    func : function
        The function whose docstring we wish to update.
    kwarg_mapping : dict
        A dict containing {old_arg: new_arg} key/value pairs, see
        `deprecate_parameter`.
    deprecated_version : str
        A major.minor version string specifying when old_arg was
        deprecated.

    Returns
    -------
    new_doc : str
        The updated docstring. Returns the original docstring if numpydoc is
        not available.
    """
    if func.__doc__ is None:
        return None
    try:
        from numpydoc.docscrape import FunctionDoc, Parameter
    except ImportError:
        # Return an unmodified docstring if numpydoc is not available.
        return func.__doc__
    Doc = FunctionDoc(func)
    # Append one "Other Parameters" entry per deprecated kwarg, rendered
    # with a Sphinx ``.. deprecated::`` directive.
    for old_arg, new_arg in kwarg_mapping.items():
        desc = []
        if new_arg is None:
            # Parameter is being removed outright.
            desc.append(f'`{old_arg}` is deprecated.')
        else:
            # Parameter is renamed; point readers at the replacement.
            desc.append(f'Deprecated in favor of `{new_arg}`.')
        desc += ['', f'.. deprecated:: {deprecated_version}']
        Doc['Other Parameters'].append(
            Parameter(name=old_arg, type='DEPRECATED', desc=desc)
        )
    # numpydoc re-renders the whole docstring from its parsed sections.
    new_docstring = str(Doc)
    # new_docstring will have a header starting with:
    #
    # .. function:: func.__name__
    #
    # and some additional blank lines. We strip these off below.
    split = new_docstring.split('\n')
    no_header = split[1:]
    while not no_header[0].strip():
        no_header.pop(0)
    # Store the initial description before any of the Parameters fields.
    # Usually this is a single line, but the while loop covers any case
    # where it is not.
    descr = no_header.pop(0)
    while no_header[0].strip():
        descr += '\n ' + no_header.pop(0)
    descr += '\n\n'
    # '\n ' rather than '\n' here to restore the original indentation.
    final_docstring = descr + '\n '.join(no_header)
    # strip any extra spaces from ends of lines
    final_docstring = '\n'.join([line.rstrip() for line in final_docstring.split('\n')])
    return final_docstring
class channel_as_last_axis:
    """Decorator that presents the channels axis as the last axis.

    Functions that only understand channels-last data can be wrapped with
    this decorator: the requested ``channel_axis`` is moved to the final
    position before the call and moved back afterwards.

    Parameters
    ----------
    channel_arg_positions : tuple of int, optional
        Indices of the positional arguments that hold multichannel arrays.
        By default only the first positional argument is treated as such.
    channel_kwarg_names : tuple of str, optional
        Names of keyword arguments that hold multichannel arrays.
    multichannel_output : bool, optional
        Whether the wrapped function returns a single multichannel array
        whose channels axis must be restored. Functions with multiple
        outputs (some or all multichannel) are not supported.
    """

    def __init__(
        self,
        channel_arg_positions=(0,),
        channel_kwarg_names=(),
        multichannel_output=True,
    ):
        self.arg_positions = set(channel_arg_positions)
        self.kwarg_names = set(channel_kwarg_names)
        self.multichannel_output = multichannel_output

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            channel_axis = kwargs.get('channel_axis', None)
            # Grayscale data: nothing to reorder.
            if channel_axis is None:
                return func(*args, **kwargs)
            # TODO: scalars are normalized to a tuple in anticipation of
            # eventually supporting a tuple of channel axes; only a single
            # axis is supported for now.
            if np.isscalar(channel_axis):
                channel_axis = (channel_axis,)
            if len(channel_axis) > 1:
                raise ValueError("only a single channel axis is currently supported")
            # Channels already last: call through unchanged.
            if channel_axis == (-1,) or channel_axis == -1:
                return func(*args, **kwargs)
            src_axis = channel_axis[0]
            if self.arg_positions:
                new_args = tuple(
                    np.moveaxis(arg, src_axis, -1) if pos in self.arg_positions else arg
                    for pos, arg in enumerate(args)
                )
            else:
                new_args = args
            for name in self.kwarg_names:
                kwargs[name] = np.moveaxis(kwargs[name], src_axis, -1)
            # All channel data is now channels-last, so tell the wrapped
            # function exactly that.
            kwargs["channel_axis"] = -1
            out = func(*new_args, **kwargs)
            if self.multichannel_output:
                # Restore the caller's original channel position.
                out = np.moveaxis(out, -1, src_axis)
            return out

        return fixed_func
class deprecate_func(_DecoratorBaseClass):
    """Mark a function as deprecated and emit a warning on every call.

    Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.

    Parameters
    ----------
    deprecated_version : str
        The package version when the deprecation was introduced.
    removed_version : str
        The package version in which the deprecated function will be removed.
    hint : str, optional
        A hint on how to address this deprecation,
        e.g., "Use `skimage.submodule.alternative_func` instead."

    Examples
    --------
    >>> @deprecate_func(
    ...     deprecated_version="1.0.0",
    ...     removed_version="1.2.0",
    ...     hint="Use `bar` instead."
    ... )
    ... def foo():
    ...     pass

    Calling ``foo`` will warn with::

        FutureWarning: `foo` is deprecated since version 1.0.0
        and will be removed in version 1.2.0. Use `bar` instead.
    """

    def __init__(self, *, deprecated_version, removed_version=None, hint=None):
        self.deprecated_version = deprecated_version
        self.removed_version = removed_version
        self.hint = hint

    def __call__(self, func):
        # Build the warning text once, at decoration time.
        message = (
            f"`{func.__name__}` is deprecated since version "
            f"{self.deprecated_version}"
        )
        if self.removed_version:
            message += f" and will be removed in version {self.removed_version}."
        if self.hint:
            # Prepend space and make sure it closes with "."
            message += f" {self.hint.rstrip('.')}."

        # Position of this wrapper in the stack of wrappers around `func`;
        # used to aim the warning at the user's call site.
        stack_rank = _count_wrappers(func)

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            stacklevel = 1 + self.get_stack_length(func) - stack_rank
            warnings.warn(message, category=FutureWarning, stacklevel=stacklevel)
            return func(*args, **kwargs)

        # Surface the deprecation prominently in the rendered docs.
        notice = f'**Deprecated:** {message}'
        if wrapped.__doc__ is None:
            wrapped.__doc__ = notice
        else:
            wrapped.__doc__ = notice + '\n\n ' + wrapped.__doc__
        return wrapped
def get_bound_method_class(m):
    """Return the class of the instance that a bound method is attached to.

    Parameters
    ----------
    m : bound method
        A method bound to a class instance.

    Returns
    -------
    cls : type
        The class of the instance ``m`` is bound to.
    """
    # The original code guarded on ``sys.version < '3'`` to support Python 2's
    # ``im_class``. That lexical string comparison is fragile (e.g. it would
    # misorder a hypothetical "10.x") and is always False on Python 3, so the
    # legacy branch is removed.
    return m.__self__.__class__
def safe_as_int(val, atol=1e-3):
    """
    Attempt to safely cast values to integer format.

    Parameters
    ----------
    val : scalar or iterable of scalars
        Number or container of numbers which are intended to be interpreted as
        integers, e.g., for indexing purposes, but which may not carry integer
        type.
    atol : float
        Absolute tolerance away from nearest integer to consider values in
        ``val`` functionally integers.

    Returns
    -------
    val_int : NumPy scalar or ndarray of dtype `np.int64`
        Returns the input value(s) coerced to dtype `np.int64` assuming all
        were within ``atol`` of the nearest integer.

    Notes
    -----
    This operation calculates ``val`` modulo 1, which returns the mantissa of
    all values. Then all mantissas greater than 0.5 are subtracted from one.
    Finally, the absolute tolerance from zero is calculated. If it is less
    than ``atol`` for all value(s) in ``val``, they are rounded and returned
    in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
    returned.

    If any value(s) are outside the specified tolerance, an informative error
    is raised.

    Examples
    --------
    >>> safe_as_int(7.0)
    7

    >>> safe_as_int([9, 4, 2.9999999999])
    array([9, 4, 3])

    >>> safe_as_int(53.1)
    Traceback (most recent call last):
        ...
    ValueError: Integer argument required but received 53.1, check inputs.

    >>> safe_as_int(53.01, atol=0.01)
    53
    """
    mod = np.asarray(val) % 1  # Extract mantissa
    # Check for and subtract any mod values > 0.5 from 1, so `mod` measures
    # distance to the *nearest* integer rather than the floor.
    if mod.ndim == 0:  # Scalar input, cannot be indexed
        if mod > 0.5:
            mod = 1 - mod
    else:  # Iterable input, now ndarray
        mod[mod > 0.5] = 1 - mod[mod > 0.5]  # Test on each side of nearest int
    # Validate directly with numpy instead of `np.testing.assert_allclose`:
    # `np.testing` is test-suite machinery and should not be on a library
    # code path.  `not all(<=)` (rather than `any(>)`) keeps NaN inputs
    # rejected, matching the previous assert_allclose behavior.
    if not np.all(mod <= atol):
        raise ValueError(
            f'Integer argument required but received {val}, check inputs.'
        )
    return np.round(val).astype(np.int64)
def check_shape_equality(*images):
    """Raise a ValueError unless every image has the first image's shape."""
    reference_shape = images[0].shape
    for image in images[1:]:
        if image.shape != reference_shape:
            raise ValueError('Input images must have the same dimensions.')
    return
def slice_at_axis(sl, axis):
    """
    Build an index tuple that applies `sl` along one axis only.

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced".

    Returns
    -------
    sl : tuple of slices
        A tuple with slices matching `shape` in length.

    Examples
    --------
    >>> slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    """
    # Full slices for the leading axes, `sl` at `axis`, and Ellipsis to
    # absorb any trailing dimensions.
    leading = (slice(None),) * axis
    return leading + (sl, Ellipsis)
def reshape_nd(arr, ndim, dim):
    """Reshape a 1D array into an `ndim`-D array with one non-singleton axis.

    Parameters
    ----------
    arr : array, shape (N,)
        Input array
    ndim : int
        Number of desired dimensions of reshaped array.
    dim : int
        Which dimension/axis will not be singleton-sized.

    Returns
    -------
    arr_reshaped : array, shape ([1, ...], N, [1,...])
        View of `arr` reshaped to the desired shape.

    Examples
    --------
    >>> rng = np.random.default_rng()
    >>> arr = rng.random(7)
    >>> reshape_nd(arr, 2, 0).shape
    (7, 1)
    >>> reshape_nd(arr, 3, 1).shape
    (1, 7, 1)
    >>> reshape_nd(arr, 4, -1).shape
    (1, 1, 1, 7)
    """
    if arr.ndim != 1:
        raise ValueError("arr must be a 1D array")
    # All axes singleton except `dim`; -1 lets numpy infer the data axis.
    # `dim % ndim` maps negative axis indices to their positive position.
    target_shape = [-1 if axis == dim % ndim else 1 for axis in range(ndim)]
    return np.reshape(arr, target_shape)
def check_nD(array, ndim, arg_name='image'):
    """
    Verify that an array is non-empty and has an allowed dimensionality.

    Parameters
    ----------
    array : array-like
        Input array to be validated
    ndim : int or iterable of ints
        Allowable ndim or ndims for the array.
    arg_name : str, optional
        The name of the array in the original function.
    """
    array = np.asanyarray(array)
    # Normalize a single allowed ndim to a one-element list.
    if isinstance(ndim, int):
        ndim = [ndim]
    if array.size == 0:
        raise ValueError("The parameter `%s` cannot be an empty array" % (arg_name))
    if array.ndim not in ndim:
        allowed = '-or-'.join(str(n) for n in ndim)
        raise ValueError(
            "The parameter `%s` must be a %s-dimensional array" % (arg_name, allowed)
        )
def convert_to_float(image, preserve_range):
    """Convert input image to float image with the appropriate range.

    Parameters
    ----------
    image : ndarray
        Input image.
    preserve_range : bool
        Determines if the range of the image should be kept or transformed
        using img_as_float. Also see
        https://scikit-image.org/docs/dev/user_guide/data_types.html

    Notes
    -----
    * Input images with `float32` data type are not upcast.

    Returns
    -------
    image : ndarray
        Transformed version of the input.
    """
    if image.dtype == np.float16:
        # float16 is promoted: many downstream routines do not support it.
        return image.astype(np.float32)
    if not preserve_range:
        # Rescale to the canonical float range via img_as_float.
        from ..util.dtype import img_as_float

        return img_as_float(image)
    # Keep the value range: only cast to double when the image is not
    # already single ('f') or double ('d') precision float.
    if image.dtype.char not in 'df':
        image = image.astype(float)
    return image
def _validate_interpolation_order(image_dtype, order):
"""Validate and return spline interpolation's order.
Parameters
----------
image_dtype : dtype
Image dtype.
order : int, optional
The order of the spline interpolation. The order has to be in
the range 0-5. See `skimage.transform.warp` for detail.
Returns
-------
order : int
if input order is None, returns 0 if image_dtype is bool and 1
otherwise. Otherwise, image_dtype is checked and input order
is validated accordingly (order > 0 is not supported for bool
image dtype)
"""
if order is None:
return 0 if image_dtype == bool else 1
if order < 0 or order > 5:
raise ValueError("Spline interpolation order has to be in the " "range 0-5.")
if image_dtype == bool and order != 0:
raise ValueError(
"Input image dtype is bool. Interpolation is not defined "
"with bool data type. Please set order to 0 or explicitly "
"cast input image to another data type."
)
return order
def _to_np_mode(mode):
"""Convert padding modes from `ndi.correlate` to `np.pad`."""
mode_translation_dict = dict(nearest='edge', reflect='symmetric', mirror='reflect')
if mode in mode_translation_dict:
mode = mode_translation_dict[mode]
return mode
def _to_ndimage_mode(mode):
"""Convert from `numpy.pad` mode name to the corresponding ndimage mode."""
mode_translation_dict = dict(
constant='constant',
edge='nearest',
symmetric='reflect',
reflect='mirror',
wrap='wrap',
)
if mode not in mode_translation_dict:
raise ValueError(
f"Unknown mode: '{mode}', or cannot translate mode. The "
f"mode should be one of 'constant', 'edge', 'symmetric', "
f"'reflect', or 'wrap'. See the documentation of numpy.pad for "
f"more info."
)
return _fix_ndimage_mode(mode_translation_dict[mode])
def _fix_ndimage_mode(mode):
# SciPy 1.6.0 introduced grid variants of constant and wrap which
# have less surprising behavior for images. Use these when available
grid_modes = {'constant': 'grid-constant', 'wrap': 'grid-wrap'}
return grid_modes.get(mode, mode)
new_float_type = {
# preserved types
np.float32().dtype.char: np.float32,
np.float64().dtype.char: np.float64,
np.complex64().dtype.char: np.complex64,
np.complex128().dtype.char: np.complex128,
# altered types
np.float16().dtype.char: np.float32,
'g': np.float64, # np.float128 ; doesn't exist on windows
'G': np.complex128, # np.complex256 ; doesn't exist on windows
}
def _supported_float_type(input_dtype, allow_complex=False):
"""Return an appropriate floating-point dtype for a given dtype.
float32, float64, complex64, complex128 are preserved.
float16 is promoted to float32.
complex256 is demoted to complex128.
Other types are cast to float64.
Parameters
----------
input_dtype : np.dtype or tuple of np.dtype
The input dtype. If a tuple of multiple dtypes is provided, each
dtype is first converted to a supported floating point type and the
final dtype is then determined by applying `np.result_type` on the
sequence of supported floating point types.
allow_complex : bool, optional
If False, raise a ValueError on complex-valued inputs.
Returns
-------
float_type : dtype
Floating-point dtype for the image.
"""
if isinstance(input_dtype, tuple):
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
input_dtype = np.dtype(input_dtype)
if not allow_complex and input_dtype.kind == 'c':
raise ValueError("complex valued input is not supported")
return new_float_type.get(input_dtype.char, np.float64)
def identity(image, *args, **kwargs):
    """Return `image` unchanged; extra arguments are accepted and ignored.

    The permissive signature lets this serve as a drop-in no-op wherever a
    filter-like callable is expected.
    """
    return image
def as_binary_ndarray(array, *, variable_name):
    """Return `array` as a numpy.ndarray of dtype bool.

    Parameters
    ----------
    array : array-like
        Input to convert; must be boolean or contain only 0s and 1s.
    variable_name : str
        Name used in the error message to identify the offending input.

    Raises
    ------
    ValueError:
        An error including the given `variable_name` if `array` can not be
        safely cast to a boolean array.
    """
    array = np.asarray(array)
    # Non-bool input is only accepted when every element is exactly 0 or 1.
    if array.dtype != bool and np.any((array != 1) & (array != 0)):
        raise ValueError(
            f"{variable_name} array is not of dtype boolean or "
            f"contains values other than 0 and 1 so cannot be "
            f"safely cast to boolean array."
        )
    return np.asarray(array, dtype=bool)

View File

@@ -0,0 +1,138 @@
import sys
from packaging import version as _version
def _check_version(actver, version, cmp_op):
    """
    Check version string of an active module against a required version.

    If dev/prerelease tags result in TypeError for string-number comparison,
    it is assumed that the dependency is satisfied.
    Users on dev branches are responsible for keeping their own packages up to
    date.
    """
    comparisons = {
        '>': lambda a, b: a > b,
        '>=': lambda a, b: a >= b,
        '=': lambda a, b: a == b,
        '<': lambda a, b: a < b,
    }
    compare = comparisons.get(cmp_op)
    if compare is None:
        # Unrecognized comparison operator.
        return False
    try:
        return compare(_version.parse(actver), _version.parse(version))
    except TypeError:
        # Dev/prerelease versions may not compare cleanly; assume satisfied.
        return True
def get_module_version(module_name):
    """Return module version or None if version can't be retrieved."""
    leaf = module_name.rpartition('.')[-1]
    mod = __import__(module_name, fromlist=[leaf])
    version = getattr(mod, '__version__', None)
    if version is None:
        version = getattr(mod, 'VERSION', None)
    return version


def is_installed(name, version=None):
    """Test if *name* is installed.

    Parameters
    ----------
    name : str
        Name of module or "python"
    version : str, optional
        Version string to test against.
        If version is not None, checking version
        (must have an attribute named '__version__' or 'VERSION')
        Version may start with =, >=, > or < to specify the exact requirement

    Returns
    -------
    out : bool
        True if `name` is installed matching the optional version.
    """
    if name.lower() == 'python':
        actver = sys.version[:6]
    else:
        try:
            actver = get_module_version(name)
        except ImportError:
            return False
    if version is None:
        return True
    # since version_requirements is in the critical import path,
    # we lazy import re
    import re

    match = re.search('[0-9]', version)
    assert match is not None, "Invalid version number"
    # Everything before the first digit is the comparison operator;
    # default to exact equality when none is given.
    symb = version[: match.start()] or '='
    assert symb in ('>=', '>', '=', '<'), f"Invalid version condition '{symb}'"
    return _check_version(actver, version[match.start() :], symb)
def require(name, version=None):
    """Return decorator that forces a requirement for a function or class.

    Parameters
    ----------
    name : str
        Name of module or "python".
    version : str, optional
        Version string to test against.
        If version is not None, checking version
        (must have an attribute named '__version__' or 'VERSION')
        Version may start with =, >=, > or < to specify the exact requirement

    Returns
    -------
    func : function
        A decorator that raises an ImportError if a function is run
        in the absence of the input dependency.
    """
    # since version_requirements is in the critical import path, we lazy import
    # functools
    import functools

    def decorator(obj):
        @functools.wraps(obj)
        def func_wrapped(*args, **kwargs):
            # Re-check at call time, so the requirement can be satisfied
            # after import.
            if not is_installed(name, version):
                msg = f'"{obj}" in "{obj.__module__}" requires "{name}'
                if version is not None:
                    msg += f" {version}"
                msg += '"'
                raise ImportError(msg)
            return obj(*args, **kwargs)

        return func_wrapped

    return decorator
def get_module(module_name, version=None):
    """Return a module object of name *module_name* if installed.

    Parameters
    ----------
    module_name : str
        Name of module.
    version : str, optional
        Version string to test against.
        If version is not None, checking version
        (must have an attribute named '__version__' or 'VERSION')
        Version may start with =, >=, > or < to specify the exact requirement

    Returns
    -------
    mod : module or None
        Module if *module_name* is installed matching the optional version
        or None otherwise.
    """
    if is_installed(module_name, version):
        leaf = module_name.rpartition('.')[-1]
        return __import__(module_name, fromlist=[leaf])
    return None