update
This commit is contained in:
884
.CondaPkg/env/Lib/site-packages/skimage/_shared/utils.py
vendored
Normal file
884
.CondaPkg/env/Lib/site-packages/skimage/_shared/utils.py
vendored
Normal file
@@ -0,0 +1,884 @@
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._warnings import all_warnings, warn
|
||||
|
||||
__all__ = [
|
||||
'deprecate_func',
|
||||
'get_bound_method_class',
|
||||
'all_warnings',
|
||||
'safe_as_int',
|
||||
'check_shape_equality',
|
||||
'check_nD',
|
||||
'warn',
|
||||
'reshape_nd',
|
||||
'identity',
|
||||
'slice_at_axis',
|
||||
"deprecate_parameter",
|
||||
"DEPRECATED",
|
||||
]
|
||||
|
||||
|
||||
def _count_wrappers(func):
|
||||
"""Count the number of wrappers around `func`."""
|
||||
unwrapped = func
|
||||
count = 0
|
||||
while hasattr(unwrapped, "__wrapped__"):
|
||||
unwrapped = unwrapped.__wrapped__
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _warning_stacklevel(func):
|
||||
"""Find stacklevel for a warning raised from a wrapper around `func`.
|
||||
|
||||
Try to determine the number of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
stacklevel : int
|
||||
The stacklevel. Minimum of 2.
|
||||
"""
|
||||
# Count number of wrappers around `func`
|
||||
wrapped_count = _count_wrappers(func)
|
||||
|
||||
# Count number of total wrappers around global version of `func`
|
||||
module = sys.modules.get(func.__module__)
|
||||
try:
|
||||
for name in func.__qualname__.split("."):
|
||||
global_func = getattr(module, name)
|
||||
except AttributeError as e:
|
||||
raise RuntimeError(
|
||||
f"Could not access `{func.__qualname__}` in {module!r}, "
|
||||
f" may be a closure. Set stacklevel manually. ",
|
||||
) from e
|
||||
else:
|
||||
global_wrapped_count = _count_wrappers(global_func)
|
||||
|
||||
stacklevel = global_wrapped_count - wrapped_count + 1
|
||||
return max(stacklevel, 2)
|
||||
|
||||
|
||||
def _get_stack_length(func):
|
||||
"""Return function call stack length."""
|
||||
_func = func.__globals__.get(func.__name__, func)
|
||||
length = _count_wrappers(_func)
|
||||
return length
|
||||
|
||||
|
||||
class _DecoratorBaseClass:
|
||||
"""Used to manage decorators' warnings stacklevel.
|
||||
|
||||
The `_stack_length` class variable is used to store the number of
|
||||
times a function is wrapped by a decorator.
|
||||
|
||||
Let `stack_length` be the total number of times a decorated
|
||||
function is wrapped, and `stack_rank` be the rank of the decorator
|
||||
in the decorators stack. The stacklevel of a warning is then
|
||||
`stacklevel = 1 + stack_length - stack_rank`.
|
||||
"""
|
||||
|
||||
_stack_length = {}
|
||||
|
||||
def get_stack_length(self, func):
|
||||
length = self._stack_length.get(func.__name__, _get_stack_length(func))
|
||||
return length
|
||||
|
||||
|
||||
class change_default_value(_DecoratorBaseClass):
|
||||
"""Decorator for changing the default value of an argument.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arg_name: str
|
||||
The name of the argument to be updated.
|
||||
new_value: any
|
||||
The argument new value.
|
||||
changed_version : str
|
||||
The package version in which the change will be introduced.
|
||||
warning_msg: str
|
||||
Optional warning message. If None, a generic warning message
|
||||
is used.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, arg_name, *, new_value, changed_version, warning_msg=None):
|
||||
self.arg_name = arg_name
|
||||
self.new_value = new_value
|
||||
self.warning_msg = warning_msg
|
||||
self.changed_version = changed_version
|
||||
|
||||
def __call__(self, func):
|
||||
parameters = inspect.signature(func).parameters
|
||||
arg_idx = list(parameters.keys()).index(self.arg_name)
|
||||
old_value = parameters[self.arg_name].default
|
||||
|
||||
stack_rank = _count_wrappers(func)
|
||||
|
||||
if self.warning_msg is None:
|
||||
self.warning_msg = (
|
||||
f'The new recommended value for {self.arg_name} is '
|
||||
f'{self.new_value}. Until version {self.changed_version}, '
|
||||
f'the default {self.arg_name} value is {old_value}. '
|
||||
f'From version {self.changed_version}, the {self.arg_name} '
|
||||
f'default value will be {self.new_value}. To avoid '
|
||||
f'this warning, please explicitly set {self.arg_name} value.'
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def fixed_func(*args, **kwargs):
|
||||
stacklevel = 1 + self.get_stack_length(func) - stack_rank
|
||||
if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys():
|
||||
# warn that arg_name default value changed:
|
||||
warnings.warn(self.warning_msg, FutureWarning, stacklevel=stacklevel)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return fixed_func
|
||||
|
||||
|
||||
class PatchClassRepr(type):
|
||||
"""Control class representations in rendered signatures."""
|
||||
|
||||
def __repr__(cls):
|
||||
return f"<{cls.__name__}>"
|
||||
|
||||
|
||||
class DEPRECATED(metaclass=PatchClassRepr):
|
||||
"""Signal value to help with deprecating parameters that use None.
|
||||
|
||||
This is a proxy object, used to signal that a parameter has not been set.
|
||||
This is useful if ``None`` is already used for a different purpose or just
|
||||
to highlight a deprecated parameter in the signature.
|
||||
"""
|
||||
|
||||
|
||||
class deprecate_parameter:
|
||||
"""Deprecate a parameter of a function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
deprecated_name : str
|
||||
The name of the deprecated parameter.
|
||||
start_version : str
|
||||
The package version in which the warning was introduced.
|
||||
stop_version : str
|
||||
The package version in which the warning will be replaced by
|
||||
an error / the deprecation is completed.
|
||||
template : str, optional
|
||||
If given, this message template is used instead of the default one.
|
||||
new_name : str, optional
|
||||
If given, the default message will recommend the new parameter name and an
|
||||
error will be raised if the user uses both old and new names for the
|
||||
same parameter.
|
||||
modify_docstring : bool, optional
|
||||
If the wrapped function has a docstring, add the deprecated parameters
|
||||
to the "Other Parameters" section.
|
||||
stacklevel : int, optional
|
||||
This decorator attempts to detect the appropriate stacklevel for the
|
||||
deprecation warning automatically. If this fails, e.g., due to
|
||||
decorating a closure, you can set the stacklevel manually. The
|
||||
outermost decorator should have stacklevel 2, the next inner one
|
||||
stacklevel 3, etc.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Assign `DEPRECATED` as the new default value for the deprecated parameter.
|
||||
This marks the status of the parameter also in the signature and rendered
|
||||
HTML docs.
|
||||
|
||||
This decorator can be stacked to deprecate more than one parameter.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from skimage._shared.utils import deprecate_parameter, DEPRECATED
|
||||
>>> @deprecate_parameter(
|
||||
... "b", new_name="c", start_version="0.1", stop_version="0.3"
|
||||
... )
|
||||
... def foo(a, b=DEPRECATED, *, c=None):
|
||||
... return a, c
|
||||
|
||||
Calling ``foo(1, b=2)`` will warn with::
|
||||
|
||||
FutureWarning: Parameter `b` is deprecated since version 0.1 and will
|
||||
be removed in 0.3 (or later). To avoid this warning, please use the
|
||||
parameter `c` instead. For more details, see the documentation of
|
||||
`foo`.
|
||||
"""
|
||||
|
||||
DEPRECATED = DEPRECATED # Make signal value accessible for convenience
|
||||
|
||||
remove_parameter_template = (
|
||||
"Parameter `{deprecated_name}` is deprecated since version "
|
||||
"{deprecated_version} and will be removed in {changed_version} (or "
|
||||
"later). To avoid this warning, please do not use the parameter "
|
||||
"`{deprecated_name}`. For more details, see the documentation of "
|
||||
"`{func_name}`."
|
||||
)
|
||||
|
||||
replace_parameter_template = (
|
||||
"Parameter `{deprecated_name}` is deprecated since version "
|
||||
"{deprecated_version} and will be removed in {changed_version} (or "
|
||||
"later). To avoid this warning, please use the parameter `{new_name}` "
|
||||
"instead. For more details, see the documentation of `{func_name}`."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
deprecated_name,
|
||||
*,
|
||||
start_version,
|
||||
stop_version,
|
||||
template=None,
|
||||
new_name=None,
|
||||
modify_docstring=True,
|
||||
stacklevel=None,
|
||||
):
|
||||
self.deprecated_name = deprecated_name
|
||||
self.new_name = new_name
|
||||
self.template = template
|
||||
self.start_version = start_version
|
||||
self.stop_version = stop_version
|
||||
self.modify_docstring = modify_docstring
|
||||
self.stacklevel = stacklevel
|
||||
|
||||
def __call__(self, func):
|
||||
parameters = inspect.signature(func).parameters
|
||||
deprecated_idx = list(parameters.keys()).index(self.deprecated_name)
|
||||
if self.new_name:
|
||||
new_idx = list(parameters.keys()).index(self.new_name)
|
||||
else:
|
||||
new_idx = False
|
||||
|
||||
if parameters[self.deprecated_name].default is not DEPRECATED:
|
||||
raise RuntimeError(
|
||||
f"Expected `{self.deprecated_name}` to have the value {DEPRECATED!r} "
|
||||
f"to indicate its status in the rendered signature."
|
||||
)
|
||||
|
||||
if self.template is not None:
|
||||
template = self.template
|
||||
elif self.new_name is not None:
|
||||
template = self.replace_parameter_template
|
||||
else:
|
||||
template = self.remove_parameter_template
|
||||
warning_message = template.format(
|
||||
deprecated_name=self.deprecated_name,
|
||||
deprecated_version=self.start_version,
|
||||
changed_version=self.stop_version,
|
||||
func_name=func.__qualname__,
|
||||
new_name=self.new_name,
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def fixed_func(*args, **kwargs):
|
||||
deprecated_value = DEPRECATED
|
||||
new_value = DEPRECATED
|
||||
|
||||
# Extract value of deprecated parameter
|
||||
if len(args) > deprecated_idx:
|
||||
deprecated_value = args[deprecated_idx]
|
||||
args = (
|
||||
args[:deprecated_idx] + (DEPRECATED,) + args[deprecated_idx + 1 :]
|
||||
)
|
||||
if self.deprecated_name in kwargs.keys():
|
||||
deprecated_value = kwargs[self.deprecated_name]
|
||||
kwargs[self.deprecated_name] = DEPRECATED
|
||||
# Extract value of new parameter (if present)
|
||||
if new_idx is not False and len(args) > new_idx:
|
||||
new_value = args[new_idx]
|
||||
if self.new_name and self.new_name in kwargs.keys():
|
||||
new_value = kwargs[self.new_name]
|
||||
|
||||
if deprecated_value is not DEPRECATED:
|
||||
stacklevel = (
|
||||
self.stacklevel
|
||||
if self.stacklevel is not None
|
||||
else _warning_stacklevel(func)
|
||||
)
|
||||
warnings.warn(
|
||||
warning_message, category=FutureWarning, stacklevel=stacklevel
|
||||
)
|
||||
|
||||
if new_value is not DEPRECATED:
|
||||
raise ValueError(
|
||||
f"Both deprecated parameter `{self.deprecated_name}` "
|
||||
f"and new parameter `{self.new_name}` are used. Use "
|
||||
f"only the latter to avoid conflicting values."
|
||||
)
|
||||
elif self.new_name is not None:
|
||||
# Assign old value to new one
|
||||
kwargs[self.new_name] = deprecated_value
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if self.modify_docstring and func.__doc__ is not None:
|
||||
newdoc = _docstring_add_deprecated(
|
||||
func, {self.deprecated_name: self.new_name}, self.start_version
|
||||
)
|
||||
fixed_func.__doc__ = newdoc
|
||||
|
||||
return fixed_func
|
||||
|
||||
|
||||
def _docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
|
||||
"""Add deprecated kwarg(s) to the "Other Params" section of a docstring.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : function
|
||||
The function whose docstring we wish to update.
|
||||
kwarg_mapping : dict
|
||||
A dict containing {old_arg: new_arg} key/value pairs, see
|
||||
`deprecate_parameter`.
|
||||
deprecated_version : str
|
||||
A major.minor version string specifying when old_arg was
|
||||
deprecated.
|
||||
|
||||
Returns
|
||||
-------
|
||||
new_doc : str
|
||||
The updated docstring. Returns the original docstring if numpydoc is
|
||||
not available.
|
||||
"""
|
||||
if func.__doc__ is None:
|
||||
return None
|
||||
try:
|
||||
from numpydoc.docscrape import FunctionDoc, Parameter
|
||||
except ImportError:
|
||||
# Return an unmodified docstring if numpydoc is not available.
|
||||
return func.__doc__
|
||||
|
||||
Doc = FunctionDoc(func)
|
||||
for old_arg, new_arg in kwarg_mapping.items():
|
||||
desc = []
|
||||
if new_arg is None:
|
||||
desc.append(f'`{old_arg}` is deprecated.')
|
||||
else:
|
||||
desc.append(f'Deprecated in favor of `{new_arg}`.')
|
||||
|
||||
desc += ['', f'.. deprecated:: {deprecated_version}']
|
||||
Doc['Other Parameters'].append(
|
||||
Parameter(name=old_arg, type='DEPRECATED', desc=desc)
|
||||
)
|
||||
new_docstring = str(Doc)
|
||||
|
||||
# new_docstring will have a header starting with:
|
||||
#
|
||||
# .. function:: func.__name__
|
||||
#
|
||||
# and some additional blank lines. We strip these off below.
|
||||
split = new_docstring.split('\n')
|
||||
no_header = split[1:]
|
||||
while not no_header[0].strip():
|
||||
no_header.pop(0)
|
||||
|
||||
# Store the initial description before any of the Parameters fields.
|
||||
# Usually this is a single line, but the while loop covers any case
|
||||
# where it is not.
|
||||
descr = no_header.pop(0)
|
||||
while no_header[0].strip():
|
||||
descr += '\n ' + no_header.pop(0)
|
||||
descr += '\n\n'
|
||||
# '\n ' rather than '\n' here to restore the original indentation.
|
||||
final_docstring = descr + '\n '.join(no_header)
|
||||
# strip any extra spaces from ends of lines
|
||||
final_docstring = '\n'.join([line.rstrip() for line in final_docstring.split('\n')])
|
||||
return final_docstring
|
||||
|
||||
|
||||
class channel_as_last_axis:
|
||||
"""Decorator for automatically making channels axis last for all arrays.
|
||||
|
||||
This decorator reorders axes for compatibility with functions that only
|
||||
support channels along the last axis. After the function call is complete
|
||||
the channels axis is restored back to its original position.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel_arg_positions : tuple of int, optional
|
||||
Positional arguments at the positions specified in this tuple are
|
||||
assumed to be multichannel arrays. The default is to assume only the
|
||||
first argument to the function is a multichannel array.
|
||||
channel_kwarg_names : tuple of str, optional
|
||||
A tuple containing the names of any keyword arguments corresponding to
|
||||
multichannel arrays.
|
||||
multichannel_output : bool, optional
|
||||
A boolean that should be True if the output of the function is not a
|
||||
multichannel array and False otherwise. This decorator does not
|
||||
currently support the general case of functions with multiple outputs
|
||||
where some or all are multichannel.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel_arg_positions=(0,),
|
||||
channel_kwarg_names=(),
|
||||
multichannel_output=True,
|
||||
):
|
||||
self.arg_positions = set(channel_arg_positions)
|
||||
self.kwarg_names = set(channel_kwarg_names)
|
||||
self.multichannel_output = multichannel_output
|
||||
|
||||
def __call__(self, func):
|
||||
@functools.wraps(func)
|
||||
def fixed_func(*args, **kwargs):
|
||||
channel_axis = kwargs.get('channel_axis', None)
|
||||
|
||||
if channel_axis is None:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# TODO: convert scalars to a tuple in anticipation of eventually
|
||||
# supporting a tuple of channel axes. Right now, only an
|
||||
# integer or a single-element tuple is supported, though.
|
||||
if np.isscalar(channel_axis):
|
||||
channel_axis = (channel_axis,)
|
||||
if len(channel_axis) > 1:
|
||||
raise ValueError("only a single channel axis is currently supported")
|
||||
|
||||
if channel_axis == (-1,) or channel_axis == -1:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if self.arg_positions:
|
||||
new_args = []
|
||||
for pos, arg in enumerate(args):
|
||||
if pos in self.arg_positions:
|
||||
new_args.append(np.moveaxis(arg, channel_axis[0], -1))
|
||||
else:
|
||||
new_args.append(arg)
|
||||
new_args = tuple(new_args)
|
||||
else:
|
||||
new_args = args
|
||||
|
||||
for name in self.kwarg_names:
|
||||
kwargs[name] = np.moveaxis(kwargs[name], channel_axis[0], -1)
|
||||
|
||||
# now that we have moved the channels axis to the last position,
|
||||
# change the channel_axis argument to -1
|
||||
kwargs["channel_axis"] = -1
|
||||
|
||||
# Call the function with the fixed arguments
|
||||
out = func(*new_args, **kwargs)
|
||||
if self.multichannel_output:
|
||||
out = np.moveaxis(out, -1, channel_axis[0])
|
||||
return out
|
||||
|
||||
return fixed_func
|
||||
|
||||
|
||||
class deprecate_func(_DecoratorBaseClass):
|
||||
"""Decorate a deprecated function and warn when it is called.
|
||||
|
||||
Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
deprecated_version : str
|
||||
The package version when the deprecation was introduced.
|
||||
removed_version : str
|
||||
The package version in which the deprecated function will be removed.
|
||||
hint : str, optional
|
||||
A hint on how to address this deprecation,
|
||||
e.g., "Use `skimage.submodule.alternative_func` instead."
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @deprecate_func(
|
||||
... deprecated_version="1.0.0",
|
||||
... removed_version="1.2.0",
|
||||
... hint="Use `bar` instead."
|
||||
... )
|
||||
... def foo():
|
||||
... pass
|
||||
|
||||
Calling ``foo`` will warn with::
|
||||
|
||||
FutureWarning: `foo` is deprecated since version 1.0.0
|
||||
and will be removed in version 1.2.0. Use `bar` instead.
|
||||
"""
|
||||
|
||||
def __init__(self, *, deprecated_version, removed_version=None, hint=None):
|
||||
self.deprecated_version = deprecated_version
|
||||
self.removed_version = removed_version
|
||||
self.hint = hint
|
||||
|
||||
def __call__(self, func):
|
||||
message = (
|
||||
f"`{func.__name__}` is deprecated since version "
|
||||
f"{self.deprecated_version}"
|
||||
)
|
||||
if self.removed_version:
|
||||
message += f" and will be removed in version {self.removed_version}."
|
||||
if self.hint:
|
||||
# Prepend space and make sure it closes with "."
|
||||
message += f" {self.hint.rstrip('.')}."
|
||||
|
||||
stack_rank = _count_wrappers(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
stacklevel = 1 + self.get_stack_length(func) - stack_rank
|
||||
warnings.warn(message, category=FutureWarning, stacklevel=stacklevel)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# modify docstring to display deprecation warning
|
||||
doc = f'**Deprecated:** {message}'
|
||||
if wrapped.__doc__ is None:
|
||||
wrapped.__doc__ = doc
|
||||
else:
|
||||
wrapped.__doc__ = doc + '\n\n ' + wrapped.__doc__
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def get_bound_method_class(m):
|
||||
"""Return the class for a bound method."""
|
||||
return m.im_class if sys.version < '3' else m.__self__.__class__
|
||||
|
||||
|
||||
def safe_as_int(val, atol=1e-3):
|
||||
"""
|
||||
Attempt to safely cast values to integer format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
val : scalar or iterable of scalars
|
||||
Number or container of numbers which are intended to be interpreted as
|
||||
integers, e.g., for indexing purposes, but which may not carry integer
|
||||
type.
|
||||
atol : float
|
||||
Absolute tolerance away from nearest integer to consider values in
|
||||
``val`` functionally integers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
val_int : NumPy scalar or ndarray of dtype `np.int64`
|
||||
Returns the input value(s) coerced to dtype `np.int64` assuming all
|
||||
were within ``atol`` of the nearest integer.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This operation calculates ``val`` modulo 1, which returns the mantissa of
|
||||
all values. Then all mantissas greater than 0.5 are subtracted from one.
|
||||
Finally, the absolute tolerance from zero is calculated. If it is less
|
||||
than ``atol`` for all value(s) in ``val``, they are rounded and returned
|
||||
in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
|
||||
returned.
|
||||
|
||||
If any value(s) are outside the specified tolerance, an informative error
|
||||
is raised.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> safe_as_int(7.0)
|
||||
7
|
||||
|
||||
>>> safe_as_int([9, 4, 2.9999999999])
|
||||
array([9, 4, 3])
|
||||
|
||||
>>> safe_as_int(53.1)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Integer argument required but received 53.1, check inputs.
|
||||
|
||||
>>> safe_as_int(53.01, atol=0.01)
|
||||
53
|
||||
|
||||
"""
|
||||
mod = np.asarray(val) % 1 # Extract mantissa
|
||||
|
||||
# Check for and subtract any mod values > 0.5 from 1
|
||||
if mod.ndim == 0: # Scalar input, cannot be indexed
|
||||
if mod > 0.5:
|
||||
mod = 1 - mod
|
||||
else: # Iterable input, now ndarray
|
||||
mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int
|
||||
|
||||
try:
|
||||
np.testing.assert_allclose(mod, 0, atol=atol)
|
||||
except AssertionError:
|
||||
raise ValueError(
|
||||
f'Integer argument required but received ' f'{val}, check inputs.'
|
||||
)
|
||||
|
||||
return np.round(val).astype(np.int64)
|
||||
|
||||
|
||||
def check_shape_equality(*images):
|
||||
"""Check that all images have the same shape"""
|
||||
image0 = images[0]
|
||||
if not all(image0.shape == image.shape for image in images[1:]):
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
return
|
||||
|
||||
|
||||
def slice_at_axis(sl, axis):
|
||||
"""
|
||||
Construct tuple of slices to slice an array in the given dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sl : slice
|
||||
The slice for the given dimension.
|
||||
axis : int
|
||||
The axis to which `sl` is applied. All other dimensions are left
|
||||
"unsliced".
|
||||
|
||||
Returns
|
||||
-------
|
||||
sl : tuple of slices
|
||||
A tuple with slices matching `shape` in length.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> slice_at_axis(slice(None, 3, -1), 1)
|
||||
(slice(None, None, None), slice(None, 3, -1), Ellipsis)
|
||||
"""
|
||||
return (slice(None),) * axis + (sl,) + (...,)
|
||||
|
||||
|
||||
def reshape_nd(arr, ndim, dim):
|
||||
"""Reshape a 1D array to have n dimensions, all singletons but one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arr : array, shape (N,)
|
||||
Input array
|
||||
ndim : int
|
||||
Number of desired dimensions of reshaped array.
|
||||
dim : int
|
||||
Which dimension/axis will not be singleton-sized.
|
||||
|
||||
Returns
|
||||
-------
|
||||
arr_reshaped : array, shape ([1, ...], N, [1,...])
|
||||
View of `arr` reshaped to the desired shape.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> rng = np.random.default_rng()
|
||||
>>> arr = rng.random(7)
|
||||
>>> reshape_nd(arr, 2, 0).shape
|
||||
(7, 1)
|
||||
>>> reshape_nd(arr, 3, 1).shape
|
||||
(1, 7, 1)
|
||||
>>> reshape_nd(arr, 4, -1).shape
|
||||
(1, 1, 1, 7)
|
||||
"""
|
||||
if arr.ndim != 1:
|
||||
raise ValueError("arr must be a 1D array")
|
||||
new_shape = [1] * ndim
|
||||
new_shape[dim] = -1
|
||||
return np.reshape(arr, new_shape)
|
||||
|
||||
|
||||
def check_nD(array, ndim, arg_name='image'):
|
||||
"""
|
||||
Verify an array meets the desired ndims and array isn't empty.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : array-like
|
||||
Input array to be validated
|
||||
ndim : int or iterable of ints
|
||||
Allowable ndim or ndims for the array.
|
||||
arg_name : str, optional
|
||||
The name of the array in the original function.
|
||||
|
||||
"""
|
||||
array = np.asanyarray(array)
|
||||
msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
|
||||
msg_empty_array = "The parameter `%s` cannot be an empty array"
|
||||
if isinstance(ndim, int):
|
||||
ndim = [ndim]
|
||||
if array.size == 0:
|
||||
raise ValueError(msg_empty_array % (arg_name))
|
||||
if array.ndim not in ndim:
|
||||
raise ValueError(
|
||||
msg_incorrect_dim % (arg_name, '-or-'.join([str(n) for n in ndim]))
|
||||
)
|
||||
|
||||
|
||||
def convert_to_float(image, preserve_range):
|
||||
"""Convert input image to float image with the appropriate range.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : ndarray
|
||||
Input image.
|
||||
preserve_range : bool
|
||||
Determines if the range of the image should be kept or transformed
|
||||
using img_as_float. Also see
|
||||
https://scikit-image.org/docs/dev/user_guide/data_types.html
|
||||
|
||||
Notes
|
||||
-----
|
||||
* Input images with `float32` data type are not upcast.
|
||||
|
||||
Returns
|
||||
-------
|
||||
image : ndarray
|
||||
Transformed version of the input.
|
||||
|
||||
"""
|
||||
if image.dtype == np.float16:
|
||||
return image.astype(np.float32)
|
||||
if preserve_range:
|
||||
# Convert image to double only if it is not single or double
|
||||
# precision float
|
||||
if image.dtype.char not in 'df':
|
||||
image = image.astype(float)
|
||||
else:
|
||||
from ..util.dtype import img_as_float
|
||||
|
||||
image = img_as_float(image)
|
||||
return image
|
||||
|
||||
|
||||
def _validate_interpolation_order(image_dtype, order):
|
||||
"""Validate and return spline interpolation's order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_dtype : dtype
|
||||
Image dtype.
|
||||
order : int, optional
|
||||
The order of the spline interpolation. The order has to be in
|
||||
the range 0-5. See `skimage.transform.warp` for detail.
|
||||
|
||||
Returns
|
||||
-------
|
||||
order : int
|
||||
if input order is None, returns 0 if image_dtype is bool and 1
|
||||
otherwise. Otherwise, image_dtype is checked and input order
|
||||
is validated accordingly (order > 0 is not supported for bool
|
||||
image dtype)
|
||||
|
||||
"""
|
||||
|
||||
if order is None:
|
||||
return 0 if image_dtype == bool else 1
|
||||
|
||||
if order < 0 or order > 5:
|
||||
raise ValueError("Spline interpolation order has to be in the " "range 0-5.")
|
||||
|
||||
if image_dtype == bool and order != 0:
|
||||
raise ValueError(
|
||||
"Input image dtype is bool. Interpolation is not defined "
|
||||
"with bool data type. Please set order to 0 or explicitly "
|
||||
"cast input image to another data type."
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def _to_np_mode(mode):
|
||||
"""Convert padding modes from `ndi.correlate` to `np.pad`."""
|
||||
mode_translation_dict = dict(nearest='edge', reflect='symmetric', mirror='reflect')
|
||||
if mode in mode_translation_dict:
|
||||
mode = mode_translation_dict[mode]
|
||||
return mode
|
||||
|
||||
|
||||
def _to_ndimage_mode(mode):
|
||||
"""Convert from `numpy.pad` mode name to the corresponding ndimage mode."""
|
||||
mode_translation_dict = dict(
|
||||
constant='constant',
|
||||
edge='nearest',
|
||||
symmetric='reflect',
|
||||
reflect='mirror',
|
||||
wrap='wrap',
|
||||
)
|
||||
if mode not in mode_translation_dict:
|
||||
raise ValueError(
|
||||
f"Unknown mode: '{mode}', or cannot translate mode. The "
|
||||
f"mode should be one of 'constant', 'edge', 'symmetric', "
|
||||
f"'reflect', or 'wrap'. See the documentation of numpy.pad for "
|
||||
f"more info."
|
||||
)
|
||||
return _fix_ndimage_mode(mode_translation_dict[mode])
|
||||
|
||||
|
||||
def _fix_ndimage_mode(mode):
|
||||
# SciPy 1.6.0 introduced grid variants of constant and wrap which
|
||||
# have less surprising behavior for images. Use these when available
|
||||
grid_modes = {'constant': 'grid-constant', 'wrap': 'grid-wrap'}
|
||||
return grid_modes.get(mode, mode)
|
||||
|
||||
|
||||
new_float_type = {
|
||||
# preserved types
|
||||
np.float32().dtype.char: np.float32,
|
||||
np.float64().dtype.char: np.float64,
|
||||
np.complex64().dtype.char: np.complex64,
|
||||
np.complex128().dtype.char: np.complex128,
|
||||
# altered types
|
||||
np.float16().dtype.char: np.float32,
|
||||
'g': np.float64, # np.float128 ; doesn't exist on windows
|
||||
'G': np.complex128, # np.complex256 ; doesn't exist on windows
|
||||
}
|
||||
|
||||
|
||||
def _supported_float_type(input_dtype, allow_complex=False):
|
||||
"""Return an appropriate floating-point dtype for a given dtype.
|
||||
|
||||
float32, float64, complex64, complex128 are preserved.
|
||||
float16 is promoted to float32.
|
||||
complex256 is demoted to complex128.
|
||||
Other types are cast to float64.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dtype : np.dtype or tuple of np.dtype
|
||||
The input dtype. If a tuple of multiple dtypes is provided, each
|
||||
dtype is first converted to a supported floating point type and the
|
||||
final dtype is then determined by applying `np.result_type` on the
|
||||
sequence of supported floating point types.
|
||||
allow_complex : bool, optional
|
||||
If False, raise a ValueError on complex-valued inputs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float_type : dtype
|
||||
Floating-point dtype for the image.
|
||||
"""
|
||||
if isinstance(input_dtype, tuple):
|
||||
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
|
||||
input_dtype = np.dtype(input_dtype)
|
||||
if not allow_complex and input_dtype.kind == 'c':
|
||||
raise ValueError("complex valued input is not supported")
|
||||
return new_float_type.get(input_dtype.char, np.float64)
|
||||
|
||||
|
||||
def identity(image, *args, **kwargs):
|
||||
"""Returns the first argument unmodified."""
|
||||
return image
|
||||
|
||||
|
||||
def as_binary_ndarray(array, *, variable_name):
|
||||
"""Return `array` as a numpy.ndarray of dtype bool.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError:
|
||||
An error including the given `variable_name` if `array` can not be
|
||||
safely cast to a boolean array.
|
||||
"""
|
||||
array = np.asarray(array)
|
||||
if array.dtype != bool:
|
||||
if np.any((array != 1) & (array != 0)):
|
||||
raise ValueError(
|
||||
f"{variable_name} array is not of dtype boolean or "
|
||||
f"contains values other than 0 and 1 so cannot be "
|
||||
f"safely cast to boolean array."
|
||||
)
|
||||
return np.asarray(array, dtype=bool)
|
||||
Reference in New Issue
Block a user