885 lines
29 KiB
Python
885 lines
29 KiB
Python
import functools
|
|
import inspect
|
|
import sys
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from ._warnings import all_warnings, warn
|
|
|
|
# Public names re-exported by ``from skimage._shared.utils import *``.
__all__ = [
    "deprecate_func",
    "get_bound_method_class",
    "all_warnings",
    "safe_as_int",
    "check_shape_equality",
    "check_nD",
    "warn",
    "reshape_nd",
    "identity",
    "slice_at_axis",
    "deprecate_parameter",
    "DEPRECATED",
]
|
|
|
|
|
|
def _count_wrappers(func):
|
|
"""Count the number of wrappers around `func`."""
|
|
unwrapped = func
|
|
count = 0
|
|
while hasattr(unwrapped, "__wrapped__"):
|
|
unwrapped = unwrapped.__wrapped__
|
|
count += 1
|
|
return count
|
|
|
|
|
|
def _warning_stacklevel(func):
|
|
"""Find stacklevel for a warning raised from a wrapper around `func`.
|
|
|
|
Try to determine the number of
|
|
|
|
Parameters
|
|
----------
|
|
func : Callable
|
|
|
|
|
|
Returns
|
|
-------
|
|
stacklevel : int
|
|
The stacklevel. Minimum of 2.
|
|
"""
|
|
# Count number of wrappers around `func`
|
|
wrapped_count = _count_wrappers(func)
|
|
|
|
# Count number of total wrappers around global version of `func`
|
|
module = sys.modules.get(func.__module__)
|
|
try:
|
|
for name in func.__qualname__.split("."):
|
|
global_func = getattr(module, name)
|
|
except AttributeError as e:
|
|
raise RuntimeError(
|
|
f"Could not access `{func.__qualname__}` in {module!r}, "
|
|
f" may be a closure. Set stacklevel manually. ",
|
|
) from e
|
|
else:
|
|
global_wrapped_count = _count_wrappers(global_func)
|
|
|
|
stacklevel = global_wrapped_count - wrapped_count + 1
|
|
return max(stacklevel, 2)
|
|
|
|
|
|
def _get_stack_length(func):
|
|
"""Return function call stack length."""
|
|
_func = func.__globals__.get(func.__name__, func)
|
|
length = _count_wrappers(_func)
|
|
return length
|
|
|
|
|
|
class _DecoratorBaseClass:
    """Base class that lets decorators compute their warning stacklevel.

    The `_stack_length` class variable maps a function name to the number of
    times that function is wrapped by decorators. Names missing from it fall
    back to counting the wrappers of the global object with the same name.

    Let `stack_length` be the total number of times a decorated
    function is wrapped, and `stack_rank` be the rank of the decorator
    in the decorators stack. The stacklevel of a warning is then
    `stacklevel = 1 + stack_length - stack_rank`.
    """

    # Class-level dict: shared by every subclass and instance.
    _stack_length = {}

    def get_stack_length(self, func):
        """Return the stack length recorded or computed for `func`."""
        # The fallback is computed eagerly, exactly like a ``dict.get`` default.
        computed = _get_stack_length(func)
        return self._stack_length.get(func.__name__, computed)
|
|
|
|
|
|
class change_default_value(_DecoratorBaseClass):
    """Decorator for changing the default value of an argument.

    Calling the decorated function without supplying `arg_name` (neither
    positionally nor as a keyword) emits a ``FutureWarning``. The call then
    proceeds unchanged with the current default.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be updated.
    new_value : any
        The argument new value.
    changed_version : str
        The package version in which the change will be introduced.
    warning_msg : str, optional
        Optional warning message. If None, a generic warning message
        is used.
    """

    def __init__(self, arg_name, *, new_value, changed_version, warning_msg=None):
        self.arg_name = arg_name
        self.new_value = new_value
        self.warning_msg = warning_msg
        self.changed_version = changed_version

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        # Raises ValueError if `arg_name` is not a parameter of `func`.
        arg_idx = list(parameters.keys()).index(self.arg_name)
        old_value = parameters[self.arg_name].default
        # A keyword-only argument (e.g. one following ``*args``) can never be
        # supplied positionally, so only the keywords decide whether it is set.
        accepts_positional = parameters[self.arg_name].kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )

        stack_rank = _count_wrappers(func)

        # Build the message locally instead of overwriting ``self.warning_msg``.
        # Otherwise reusing this decorator instance on another function would
        # repeat the first function's default value in the message.
        warning_msg = self.warning_msg
        if warning_msg is None:
            warning_msg = (
                f'The new recommended value for {self.arg_name} is '
                f'{self.new_value}. Until version {self.changed_version}, '
                f'the default {self.arg_name} value is {old_value}. '
                f'From version {self.changed_version}, the {self.arg_name} '
                f'default value will be {self.new_value}. To avoid '
                f'this warning, please explicitly set {self.arg_name} value.'
            )

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            stacklevel = 1 + self.get_stack_length(func) - stack_rank
            passed_positionally = accepts_positional and len(args) > arg_idx
            if not passed_positionally and self.arg_name not in kwargs:
                # warn that arg_name default value changed:
                warnings.warn(warning_msg, FutureWarning, stacklevel=stacklevel)
            return func(*args, **kwargs)

        return fixed_func
|
|
|
|
|
|
class PatchClassRepr(type):
    """Metaclass rendering its classes as ``<ClassName>``.

    This keeps class representations short in rendered signatures, e.g. when
    such a class is used as a default parameter value.
    """

    def __repr__(cls):
        class_name = cls.__name__
        return "<" + class_name + ">"
|
|
|
|
|
|
class DEPRECATED(metaclass=PatchClassRepr):
    """Sentinel marking a deprecated parameter that has not been set.

    Use the class object itself (not an instance) as the default value of a
    deprecated parameter. This is useful when ``None`` already has a meaning
    for that parameter, or simply to highlight the deprecation in the
    signature, where the metaclass renders it as ``<DEPRECATED>``.
    """
|
|
|
|
|
|
class deprecate_parameter:
    """Deprecate a parameter of a function.

    Parameters
    ----------
    deprecated_name : str
        The name of the deprecated parameter.
    start_version : str
        The package version in which the warning was introduced.
    stop_version : str
        The package version in which the warning will be replaced by
        an error / the deprecation is completed.
    template : str, optional
        If given, this message template is used instead of the default one.
    new_name : str, optional
        If given, the default message will recommend the new parameter name and an
        error will be raised if the user uses both old and new names for the
        same parameter.
    modify_docstring : bool, optional
        If the wrapped function has a docstring, add the deprecated parameters
        to the "Other Parameters" section.
    stacklevel : int, optional
        This decorator attempts to detect the appropriate stacklevel for the
        deprecation warning automatically. If this fails, e.g., due to
        decorating a closure, you can set the stacklevel manually. The
        outermost decorator should have stacklevel 2, the next inner one
        stacklevel 3, etc.

    Notes
    -----
    Assign `DEPRECATED` as the new default value for the deprecated parameter.
    This marks the status of the parameter also in the signature and rendered
    HTML docs.

    This decorator can be stacked to deprecate more than one parameter.

    Positional arguments are only matched to parameters that can be passed
    positionally. Keyword-only parameters, including those after ``*args``,
    are looked up in the keyword arguments only.

    Examples
    --------
    >>> from skimage._shared.utils import deprecate_parameter, DEPRECATED
    >>> @deprecate_parameter(
    ...     "b", new_name="c", start_version="0.1", stop_version="0.3"
    ... )
    ... def foo(a, b=DEPRECATED, *, c=None):
    ...     return a, c

    Calling ``foo(1, b=2)`` will warn with::

        FutureWarning: Parameter `b` is deprecated since version 0.1 and will
        be removed in 0.3 (or later). To avoid this warning, please use the
        parameter `c` instead. For more details, see the documentation of
        `foo`.
    """

    DEPRECATED = DEPRECATED  # Make signal value accessible for convenience

    remove_parameter_template = (
        "Parameter `{deprecated_name}` is deprecated since version "
        "{deprecated_version} and will be removed in {changed_version} (or "
        "later). To avoid this warning, please do not use the parameter "
        "`{deprecated_name}`. For more details, see the documentation of "
        "`{func_name}`."
    )

    replace_parameter_template = (
        "Parameter `{deprecated_name}` is deprecated since version "
        "{deprecated_version} and will be removed in {changed_version} (or "
        "later). To avoid this warning, please use the parameter `{new_name}` "
        "instead. For more details, see the documentation of `{func_name}`."
    )

    def __init__(
        self,
        deprecated_name,
        *,
        start_version,
        stop_version,
        template=None,
        new_name=None,
        modify_docstring=True,
        stacklevel=None,
    ):
        self.deprecated_name = deprecated_name
        self.new_name = new_name
        self.template = template
        self.start_version = start_version
        self.stop_version = stop_version
        self.modify_docstring = modify_docstring
        self.stacklevel = stacklevel

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        param_names = list(parameters.keys())
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )

        def positional_index(name):
            # Return the index in ``args`` at which `name` is received. Return
            # None if the parameter can only be passed by keyword; otherwise
            # values collected by ``*args`` could be mistaken for it.
            idx = param_names.index(name)  # ValueError if `name` is unknown
            return idx if parameters[name].kind in positional_kinds else None

        deprecated_idx = positional_index(self.deprecated_name)
        new_idx = positional_index(self.new_name) if self.new_name else None

        if parameters[self.deprecated_name].default is not DEPRECATED:
            raise RuntimeError(
                f"Expected `{self.deprecated_name}` to have the value {DEPRECATED!r} "
                f"to indicate its status in the rendered signature."
            )

        if self.template is not None:
            template = self.template
        elif self.new_name is not None:
            template = self.replace_parameter_template
        else:
            template = self.remove_parameter_template
        warning_message = template.format(
            deprecated_name=self.deprecated_name,
            deprecated_version=self.start_version,
            changed_version=self.stop_version,
            func_name=func.__qualname__,
            new_name=self.new_name,
        )

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            deprecated_value = DEPRECATED
            new_value = DEPRECATED

            # Extract value of deprecated parameter and replace it by the
            # sentinel, so `func` itself never sees the deprecated value.
            if deprecated_idx is not None and len(args) > deprecated_idx:
                deprecated_value = args[deprecated_idx]
                args = (
                    args[:deprecated_idx] + (DEPRECATED,) + args[deprecated_idx + 1 :]
                )
            if self.deprecated_name in kwargs:
                deprecated_value = kwargs[self.deprecated_name]
                kwargs[self.deprecated_name] = DEPRECATED

            # Extract value of new parameter (if present)
            if new_idx is not None and len(args) > new_idx:
                new_value = args[new_idx]
            if self.new_name and self.new_name in kwargs:
                new_value = kwargs[self.new_name]

            if deprecated_value is not DEPRECATED:
                # Computed per call: the module-level object may gain more
                # decorators after this one is applied.
                stacklevel = (
                    self.stacklevel
                    if self.stacklevel is not None
                    else _warning_stacklevel(func)
                )
                warnings.warn(
                    warning_message, category=FutureWarning, stacklevel=stacklevel
                )

                if new_value is not DEPRECATED:
                    raise ValueError(
                        f"Both deprecated parameter `{self.deprecated_name}` "
                        f"and new parameter `{self.new_name}` are used. Use "
                        f"only the latter to avoid conflicting values."
                    )
                elif self.new_name is not None:
                    # Assign old value to new one
                    kwargs[self.new_name] = deprecated_value

            return func(*args, **kwargs)

        if self.modify_docstring and func.__doc__ is not None:
            newdoc = _docstring_add_deprecated(
                func, {self.deprecated_name: self.new_name}, self.start_version
            )
            fixed_func.__doc__ = newdoc

        return fixed_func
|
|
|
|
|
|
def _docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
|
|
"""Add deprecated kwarg(s) to the "Other Params" section of a docstring.
|
|
|
|
Parameters
|
|
----------
|
|
func : function
|
|
The function whose docstring we wish to update.
|
|
kwarg_mapping : dict
|
|
A dict containing {old_arg: new_arg} key/value pairs, see
|
|
`deprecate_parameter`.
|
|
deprecated_version : str
|
|
A major.minor version string specifying when old_arg was
|
|
deprecated.
|
|
|
|
Returns
|
|
-------
|
|
new_doc : str
|
|
The updated docstring. Returns the original docstring if numpydoc is
|
|
not available.
|
|
"""
|
|
if func.__doc__ is None:
|
|
return None
|
|
try:
|
|
from numpydoc.docscrape import FunctionDoc, Parameter
|
|
except ImportError:
|
|
# Return an unmodified docstring if numpydoc is not available.
|
|
return func.__doc__
|
|
|
|
Doc = FunctionDoc(func)
|
|
for old_arg, new_arg in kwarg_mapping.items():
|
|
desc = []
|
|
if new_arg is None:
|
|
desc.append(f'`{old_arg}` is deprecated.')
|
|
else:
|
|
desc.append(f'Deprecated in favor of `{new_arg}`.')
|
|
|
|
desc += ['', f'.. deprecated:: {deprecated_version}']
|
|
Doc['Other Parameters'].append(
|
|
Parameter(name=old_arg, type='DEPRECATED', desc=desc)
|
|
)
|
|
new_docstring = str(Doc)
|
|
|
|
# new_docstring will have a header starting with:
|
|
#
|
|
# .. function:: func.__name__
|
|
#
|
|
# and some additional blank lines. We strip these off below.
|
|
split = new_docstring.split('\n')
|
|
no_header = split[1:]
|
|
while not no_header[0].strip():
|
|
no_header.pop(0)
|
|
|
|
# Store the initial description before any of the Parameters fields.
|
|
# Usually this is a single line, but the while loop covers any case
|
|
# where it is not.
|
|
descr = no_header.pop(0)
|
|
while no_header[0].strip():
|
|
descr += '\n ' + no_header.pop(0)
|
|
descr += '\n\n'
|
|
# '\n ' rather than '\n' here to restore the original indentation.
|
|
final_docstring = descr + '\n '.join(no_header)
|
|
# strip any extra spaces from ends of lines
|
|
final_docstring = '\n'.join([line.rstrip() for line in final_docstring.split('\n')])
|
|
return final_docstring
|
|
|
|
|
|
class channel_as_last_axis:
    """Decorator for automatically making channels axis last for all arrays.

    This decorator reorders axes for compatibility with functions that only
    support channels along the last axis. After the function call is complete
    the channels axis is restored back to its original position.

    The reordering only happens when the ``channel_axis`` keyword argument is
    given and is not None or -1. The wrapped function then receives
    ``channel_axis=-1``.

    Parameters
    ----------
    channel_arg_positions : tuple of int, optional
        Positional arguments at the positions specified in this tuple are
        assumed to be multichannel arrays. The default is to assume only the
        first argument to the function is a multichannel array.
    channel_kwarg_names : tuple of str, optional
        A tuple containing the names of any keyword arguments corresponding to
        multichannel arrays. Names the caller does not pass are skipped.
    multichannel_output : bool, optional
        A boolean that should be True if the output of the function is a
        multichannel array and False otherwise. This decorator does not
        currently support the general case of functions with multiple outputs
        where some or all are multichannel.
    """

    def __init__(
        self,
        channel_arg_positions=(0,),
        channel_kwarg_names=(),
        multichannel_output=True,
    ):
        self.arg_positions = set(channel_arg_positions)
        self.kwarg_names = set(channel_kwarg_names)
        self.multichannel_output = multichannel_output

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            channel_axis = kwargs.get('channel_axis', None)

            if channel_axis is None:
                return func(*args, **kwargs)

            # TODO: convert scalars to a tuple in anticipation of eventually
            #       supporting a tuple of channel axes. Right now, only an
            #       integer or a single-element tuple is supported, though.
            if np.isscalar(channel_axis):
                channel_axis = (channel_axis,)
            if len(channel_axis) > 1:
                raise ValueError("only a single channel axis is currently supported")

            if channel_axis == (-1,) or channel_axis == -1:
                return func(*args, **kwargs)

            axis = channel_axis[0]
            new_args = tuple(
                np.moveaxis(arg, axis, -1) if pos in self.arg_positions else arg
                for pos, arg in enumerate(args)
            )

            for name in self.kwarg_names:
                # Multichannel keyword arguments may be optional; only move
                # the axis of those the caller actually passed.
                if name in kwargs:
                    kwargs[name] = np.moveaxis(kwargs[name], axis, -1)

            # now that we have moved the channels axis to the last position,
            # change the channel_axis argument to -1
            kwargs["channel_axis"] = -1

            # Call the function with the fixed arguments
            out = func(*new_args, **kwargs)
            if self.multichannel_output:
                out = np.moveaxis(out, -1, axis)
            return out

        return fixed_func
|
|
|
|
|
|
class deprecate_func(_DecoratorBaseClass):
    """Decorate a deprecated function and warn when it is called.

    Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.

    Parameters
    ----------
    deprecated_version : str
        The package version when the deprecation was introduced.
    removed_version : str, optional
        The package version in which the deprecated function will be removed.
    hint : str, optional
        A hint on how to address this deprecation,
        e.g., "Use `skimage.submodule.alternative_func` instead."

    Examples
    --------
    >>> @deprecate_func(
    ...     deprecated_version="1.0.0",
    ...     removed_version="1.2.0",
    ...     hint="Use `bar` instead."
    ... )
    ... def foo():
    ...     pass

    Calling ``foo`` will warn with::

        FutureWarning: `foo` is deprecated since version 1.0.0
        and will be removed in version 1.2.0. Use `bar` instead.
    """

    def __init__(self, *, deprecated_version, removed_version=None, hint=None):
        self.deprecated_version = deprecated_version
        self.removed_version = removed_version
        self.hint = hint

    def __call__(self, func):
        message = (
            f"`{func.__name__}` is deprecated since version "
            f"{self.deprecated_version}"
        )
        if self.removed_version:
            message += f" and will be removed in version {self.removed_version}."
        else:
            # Close the sentence here too; otherwise an appended hint would
            # run on directly after the version number.
            message += "."
        if self.hint:
            # Prepend space and make sure it closes with "."
            message += f" {self.hint.rstrip('.')}."

        stack_rank = _count_wrappers(func)

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            stacklevel = 1 + self.get_stack_length(func) - stack_rank
            warnings.warn(message, category=FutureWarning, stacklevel=stacklevel)
            return func(*args, **kwargs)

        # modify docstring to display deprecation warning
        doc = f'**Deprecated:** {message}'
        if wrapped.__doc__ is None:
            wrapped.__doc__ = doc
        else:
            wrapped.__doc__ = doc + '\n\n ' + wrapped.__doc__

        return wrapped
|
|
|
|
|
|
def get_bound_method_class(m):
    """Return the class for a bound method.

    Parameters
    ----------
    m : method
        A bound method.

    Returns
    -------
    cls : type
        The class of the instance `m` is bound to (``m.__self__``).
    """
    # The former Python 2 branch (``m.im_class``) was unreachable: this module
    # requires Python 3 (f-strings), and comparing ``sys.version`` as a
    # string is not a reliable version check anyway.
    return m.__self__.__class__
|
|
|
|
|
|
def safe_as_int(val, atol=1e-3):
    """
    Attempt to safely cast values to integer format.

    Parameters
    ----------
    val : scalar or iterable of scalars
        Number or container of numbers which are intended to be interpreted as
        integers, e.g., for indexing purposes, but which may not carry integer
        type.
    atol : float
        Absolute tolerance away from nearest integer to consider values in
        ``val`` functionally integers.

    Returns
    -------
    val_int : NumPy scalar or ndarray of dtype `np.int64`
        Returns the input value(s) coerced to dtype `np.int64` assuming all
        were within ``atol`` of the nearest integer.

    Raises
    ------
    ValueError
        If any value is farther than ``atol`` from the nearest integer.

    Notes
    -----
    This operation calculates ``val`` modulo 1, which returns the fractional
    part of all values. Then all fractional parts greater than 0.5 are
    subtracted from one, giving the distance to the nearest integer. If that
    distance is at most ``atol`` for all value(s) in ``val``, they are rounded
    and returned in an integer array. Or, if ``val`` was a scalar, a NumPy
    scalar type is returned.

    If any value(s) are outside the specified tolerance, an informative error
    is raised.

    Examples
    --------
    >>> safe_as_int(7.0)
    7

    >>> safe_as_int([9, 4, 2.9999999999])
    array([9, 4, 3])

    >>> safe_as_int(53.1)
    Traceback (most recent call last):
        ...
    ValueError: Integer argument required but received 53.1, check inputs.

    >>> safe_as_int(53.01, atol=0.01)
    53

    """
    mod = np.asarray(val) % 1  # Extract the fractional part, in [0, 1)

    # Distance to the nearest integer: fractions above 0.5 are closer to the
    # next integer up. ``np.where`` treats scalars and arrays alike.
    mod = np.where(mod > 0.5, 1 - mod, mod)

    # Same criterion as ``assert_allclose(mod, 0, atol=atol)``, which was used
    # here before, without relying on a testing helper for control flow. NaN
    # never satisfies the comparison and is therefore rejected.
    if not np.all(np.abs(mod) <= atol):
        raise ValueError(f'Integer argument required but received {val}, check inputs.')

    return np.round(val).astype(np.int64)
|
|
|
|
|
|
def check_shape_equality(*images):
    """Check that all images have the same shape.

    Raises
    ------
    ValueError
        If any image's shape differs from that of the first image.
    """
    first = images[0]
    for other in images[1:]:
        if first.shape != other.shape:
            raise ValueError('Input images must have the same dimensions.')
|
|
|
|
|
|
def slice_at_axis(sl, axis):
    """
    Construct tuple of slices to slice an array in the given dimension.

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced". Negative values count from the last axis, as in NumPy.

    Returns
    -------
    sl : tuple of slices
        An index tuple applying `sl` at `axis` of an array of any
        dimensionality that has that axis. An Ellipsis absorbs the remaining
        dimensions: trailing ones for non-negative `axis`, leading ones for
        negative `axis`.

    Examples
    --------
    >>> slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    >>> slice_at_axis(slice(None, 3, -1), -1)
    (Ellipsis, slice(None, 3, -1))
    """
    if axis < 0:
        # Anchor at the end: previously ``(slice(None),) * axis`` was empty
        # for negative axes, so `sl` was silently applied to axis 0 instead.
        return (...,) + (sl,) + (slice(None),) * (-axis - 1)
    return (slice(None),) * axis + (sl,) + (...,)
|
|
|
|
|
|
def reshape_nd(arr, ndim, dim):
    """Reshape a 1D array to have n dimensions, all singletons but one.

    Parameters
    ----------
    arr : array, shape (N,)
        Input array
    ndim : int
        Number of desired dimensions of reshaped array.
    dim : int
        Which dimension/axis will not be singleton-sized.

    Returns
    -------
    arr_reshaped : array, shape ([1, ...], N, [1,...])
        View of `arr` reshaped to the desired shape.

    Raises
    ------
    ValueError
        If `arr` is not one-dimensional.

    Examples
    --------
    >>> rng = np.random.default_rng()
    >>> arr = rng.random(7)
    >>> reshape_nd(arr, 2, 0).shape
    (7, 1)
    >>> reshape_nd(arr, 3, 1).shape
    (1, 7, 1)
    >>> reshape_nd(arr, 4, -1).shape
    (1, 1, 1, 7)
    """
    if arr.ndim == 1:
        # All axes are singletons except `dim`, which takes the full length.
        target_shape = ndim * [1]
        target_shape[dim] = -1
        return np.reshape(arr, target_shape)
    raise ValueError("arr must be a 1D array")
|
|
|
|
|
|
def check_nD(array, ndim, arg_name='image'):
    """
    Verify an array meets the desired ndims and array isn't empty.

    Parameters
    ----------
    array : array-like
        Input array to be validated
    ndim : int or iterable of ints
        Allowable ndim or ndims for the array. NumPy integer scalars are
        accepted as well.
    arg_name : str, optional
        The name of the array in the original function.

    Raises
    ------
    ValueError
        If `array` is empty or its number of dimensions is not allowed.
    """
    array = np.asanyarray(array)
    msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
    msg_empty_array = "The parameter `%s` cannot be an empty array"
    # Also wrap NumPy integers (e.g. ``np.int64(2)``): otherwise the
    # membership test below would fail with a TypeError.
    if isinstance(ndim, (int, np.integer)):
        ndim = [ndim]
    if array.size == 0:
        raise ValueError(msg_empty_array % (arg_name,))
    if array.ndim not in ndim:
        raise ValueError(
            msg_incorrect_dim % (arg_name, '-or-'.join([str(n) for n in ndim]))
        )
|
|
|
|
|
|
def convert_to_float(image, preserve_range):
    """Convert input image to float image with the appropriate range.

    Parameters
    ----------
    image : ndarray
        Input image.
    preserve_range : bool
        Determines if the range of the image should be kept or transformed
        using img_as_float. Also see
        https://scikit-image.org/docs/dev/user_guide/data_types.html

    Returns
    -------
    image : ndarray
        Transformed version of the input.

    Notes
    -----
    * Input images with `float32` data type are not upcast.
    * Input images with `float16` data type are cast to `float32`, whatever
      the value of `preserve_range`.
    """
    if image.dtype == np.float16:
        return image.astype(np.float32)

    if not preserve_range:
        from ..util.dtype import img_as_float

        return img_as_float(image)

    # Single ('f') and double ('d') precision floats are returned as they are.
    # Any other dtype is cast to Python float, i.e. float64.
    if image.dtype.char in 'df':
        return image
    return image.astype(float)
|
|
|
|
|
|
def _validate_interpolation_order(image_dtype, order):
|
|
"""Validate and return spline interpolation's order.
|
|
|
|
Parameters
|
|
----------
|
|
image_dtype : dtype
|
|
Image dtype.
|
|
order : int, optional
|
|
The order of the spline interpolation. The order has to be in
|
|
the range 0-5. See `skimage.transform.warp` for detail.
|
|
|
|
Returns
|
|
-------
|
|
order : int
|
|
if input order is None, returns 0 if image_dtype is bool and 1
|
|
otherwise. Otherwise, image_dtype is checked and input order
|
|
is validated accordingly (order > 0 is not supported for bool
|
|
image dtype)
|
|
|
|
"""
|
|
|
|
if order is None:
|
|
return 0 if image_dtype == bool else 1
|
|
|
|
if order < 0 or order > 5:
|
|
raise ValueError("Spline interpolation order has to be in the " "range 0-5.")
|
|
|
|
if image_dtype == bool and order != 0:
|
|
raise ValueError(
|
|
"Input image dtype is bool. Interpolation is not defined "
|
|
"with bool data type. Please set order to 0 or explicitly "
|
|
"cast input image to another data type."
|
|
)
|
|
|
|
return order
|
|
|
|
|
|
def _to_np_mode(mode):
|
|
"""Convert padding modes from `ndi.correlate` to `np.pad`."""
|
|
mode_translation_dict = dict(nearest='edge', reflect='symmetric', mirror='reflect')
|
|
if mode in mode_translation_dict:
|
|
mode = mode_translation_dict[mode]
|
|
return mode
|
|
|
|
|
|
def _to_ndimage_mode(mode):
    """Convert from `numpy.pad` mode name to the corresponding ndimage mode.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported `numpy.pad` modes.
    """
    np_to_ndimage = {
        'constant': 'constant',
        'edge': 'nearest',
        'symmetric': 'reflect',
        'reflect': 'mirror',
        'wrap': 'wrap',
    }
    if mode not in np_to_ndimage:
        raise ValueError(
            f"Unknown mode: '{mode}', or cannot translate mode. The "
            "mode should be one of 'constant', 'edge', 'symmetric', "
            "'reflect', or 'wrap'. See the documentation of numpy.pad for "
            "more info."
        )
    ndimage_mode = np_to_ndimage[mode]
    return _fix_ndimage_mode(ndimage_mode)
|
|
|
|
|
|
def _fix_ndimage_mode(mode):
|
|
# SciPy 1.6.0 introduced grid variants of constant and wrap which
|
|
# have less surprising behavior for images. Use these when available
|
|
grid_modes = {'constant': 'grid-constant', 'wrap': 'grid-wrap'}
|
|
return grid_modes.get(mode, mode)
|
|
|
|
|
|
new_float_type = {
|
|
# preserved types
|
|
np.float32().dtype.char: np.float32,
|
|
np.float64().dtype.char: np.float64,
|
|
np.complex64().dtype.char: np.complex64,
|
|
np.complex128().dtype.char: np.complex128,
|
|
# altered types
|
|
np.float16().dtype.char: np.float32,
|
|
'g': np.float64, # np.float128 ; doesn't exist on windows
|
|
'G': np.complex128, # np.complex256 ; doesn't exist on windows
|
|
}
|
|
|
|
|
|
def _supported_float_type(input_dtype, allow_complex=False):
|
|
"""Return an appropriate floating-point dtype for a given dtype.
|
|
|
|
float32, float64, complex64, complex128 are preserved.
|
|
float16 is promoted to float32.
|
|
complex256 is demoted to complex128.
|
|
Other types are cast to float64.
|
|
|
|
Parameters
|
|
----------
|
|
input_dtype : np.dtype or tuple of np.dtype
|
|
The input dtype. If a tuple of multiple dtypes is provided, each
|
|
dtype is first converted to a supported floating point type and the
|
|
final dtype is then determined by applying `np.result_type` on the
|
|
sequence of supported floating point types.
|
|
allow_complex : bool, optional
|
|
If False, raise a ValueError on complex-valued inputs.
|
|
|
|
Returns
|
|
-------
|
|
float_type : dtype
|
|
Floating-point dtype for the image.
|
|
"""
|
|
if isinstance(input_dtype, tuple):
|
|
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
|
|
input_dtype = np.dtype(input_dtype)
|
|
if not allow_complex and input_dtype.kind == 'c':
|
|
raise ValueError("complex valued input is not supported")
|
|
return new_float_type.get(input_dtype.char, np.float64)
|
|
|
|
|
|
def identity(image, *args, **kwargs):
    """Return `image` itself; any further arguments are ignored."""
    result = image
    return result
|
|
|
|
|
|
def as_binary_ndarray(array, *, variable_name):
    """Return `array` as a numpy.ndarray of dtype bool.

    Parameters
    ----------
    array : array_like
        Input to convert. Non-boolean input may only contain 0 and 1.
    variable_name : str
        Name used to refer to `array` in the error message.

    Returns
    -------
    ndarray of bool
        `array` viewed or converted as a boolean array.

    Raises
    ------
    ValueError:
        An error including the given `variable_name` if `array` can not be
        safely cast to a boolean array.
    """
    array = np.asarray(array)
    # Only non-bool input needs the 0/1 check; the `and` short-circuits it.
    has_non_binary = array.dtype != bool and np.any((array != 1) & (array != 0))
    if has_non_binary:
        raise ValueError(
            f"{variable_name} array is not of dtype boolean or "
            f"contains values other than 0 and 1 so cannot be "
            f"safely cast to boolean array."
        )
    return np.asarray(array, dtype=bool)
|