using for loop to install conda package
This commit is contained in:
40
.CondaPkg/env/Lib/site-packages/pywt/__init__.py
vendored
Normal file
40
.CondaPkg/env/Lib/site-packages/pywt/__init__.py
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
# flake8: noqa
|
||||
|
||||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2020 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See LICENSE for more details.
|
||||
|
||||
"""
|
||||
Discrete forward and inverse wavelet transform, stationary wavelet transform,
|
||||
wavelet packets signal decomposition and reconstruction module.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from ._extensions._pywt import *
|
||||
from ._functions import *
|
||||
from ._multilevel import *
|
||||
from ._multidim import *
|
||||
from ._thresholding import *
|
||||
from ._wavelet_packets import *
|
||||
from ._dwt import *
|
||||
from ._swt import *
|
||||
from ._cwt import *
|
||||
from ._mra import *
|
||||
|
||||
from . import data
|
||||
|
||||
__all__ = [s for s in dir() if not s.startswith('_')]
|
||||
try:
|
||||
# In Python 2.x the name of the tempvar leaks out of the list
|
||||
# comprehension. Delete it to not make it show up in the main namespace.
|
||||
del s
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
from pywt.version import version as __version__
|
||||
|
||||
from ._pytesttester import PytestTester
|
||||
test = PytestTester(__name__)
|
||||
del PytestTester
|
||||
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_c99_config.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_c99_config.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_cwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_cwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_doc_utils.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_doc_utils.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_dwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_dwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_functions.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_functions.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_mra.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_mra.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_multidim.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_multidim.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_multilevel.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_multilevel.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_pytest.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_pytest.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_pytesttester.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_pytesttester.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_swt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_swt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_thresholding.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_thresholding.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_utils.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_utils.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_wavelet_packets.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/_wavelet_packets.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/conftest.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/conftest.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/version.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/__pycache__/version.cpython-311.pyc
vendored
Normal file
Binary file not shown.
3
.CondaPkg/env/Lib/site-packages/pywt/_c99_config.py
vendored
Normal file
3
.CondaPkg/env/Lib/site-packages/pywt/_c99_config.py
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# Autogenerated file containing compile-time definitions
|
||||
|
||||
_have_c99_complex = 0
|
||||
203
.CondaPkg/env/Lib/site-packages/pywt/_cwt.py
vendored
Normal file
203
.CondaPkg/env/Lib/site-packages/pywt/_cwt.py
vendored
Normal file
@@ -0,0 +1,203 @@
|
||||
from math import floor, ceil
|
||||
|
||||
from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
|
||||
Wavelet, _check_dtype)
|
||||
from ._functions import integrate_wavelet, scale2frequency
|
||||
|
||||
|
||||
__all__ = ["cwt"]
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
# Prefer scipy.fft (new in SciPy 1.4)
|
||||
import scipy.fft
|
||||
fftmodule = scipy.fft
|
||||
next_fast_len = fftmodule.next_fast_len
|
||||
except ImportError:
|
||||
try:
|
||||
import scipy.fftpack
|
||||
fftmodule = scipy.fftpack
|
||||
next_fast_len = fftmodule.next_fast_len
|
||||
except ImportError:
|
||||
fftmodule = np.fft
|
||||
|
||||
# provide a fallback so scipy is an optional requirement
|
||||
def next_fast_len(n):
|
||||
"""Round up size to the nearest power of two.
|
||||
|
||||
Given a number of samples `n`, returns the next power of two
|
||||
following this number to take advantage of FFT speedup.
|
||||
This fallback is less efficient than `scipy.fftpack.next_fast_len`
|
||||
"""
|
||||
return 2**ceil(np.log2(n))
|
||||
|
||||
|
||||
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
|
||||
"""
|
||||
cwt(data, scales, wavelet)
|
||||
|
||||
One dimensional Continuous Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
Input signal
|
||||
scales : array_like
|
||||
The wavelet scales to use. One can use
|
||||
``f = scale2frequency(wavelet, scale)/sampling_period`` to determine
|
||||
what physical frequency, ``f``. Here, ``f`` is in hertz when the
|
||||
``sampling_period`` is given in seconds.
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
sampling_period : float
|
||||
Sampling period for the frequencies output (optional).
|
||||
The values computed for ``coefs`` are independent of the choice of
|
||||
``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
|
||||
period).
|
||||
method : {'conv', 'fft'}, optional
|
||||
The method used to compute the CWT. Can be any of:
|
||||
- ``conv`` uses ``numpy.convolve``.
|
||||
- ``fft`` uses frequency domain convolution.
|
||||
- ``auto`` uses automatic selection based on an estimate of the
|
||||
computational complexity at each scale.
|
||||
|
||||
The ``conv`` method complexity is ``O(len(scale) * len(data))``.
|
||||
The ``fft`` method is ``O(N * log2(N))`` with
|
||||
``N = len(scale) + len(data) - 1``. It is well suited for large size
|
||||
signals but slightly slower than ``conv`` on small ones.
|
||||
axis: int, optional
|
||||
Axis over which to compute the CWT. If not given, the last axis is
|
||||
used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coefs : array_like
|
||||
Continuous wavelet transform of the input signal for the given scales
|
||||
and wavelet. The first axis of ``coefs`` corresponds to the scales.
|
||||
The remaining axes match the shape of ``data``.
|
||||
frequencies : array_like
|
||||
If the unit of sampling period are seconds and given, than frequencies
|
||||
are in hertz. Otherwise, a sampling period of 1 is assumed.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Size of coefficients arrays depends on the length of the input array and
|
||||
the length of given scales.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> import numpy as np
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> x = np.arange(512)
|
||||
>>> y = np.sin(2*np.pi*x/32)
|
||||
>>> coef, freqs=pywt.cwt(y,np.arange(1,129),'gaus1')
|
||||
>>> plt.matshow(coef) # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
----------
|
||||
>>> import pywt
|
||||
>>> import numpy as np
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> t = np.linspace(-1, 1, 200, endpoint=False)
|
||||
>>> sig = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2)*np.exp(1j*2*np.pi*2*(t-0.4)))
|
||||
>>> widths = np.arange(1, 31)
|
||||
>>> cwtmatr, freqs = pywt.cwt(sig, widths, 'mexh')
|
||||
>>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
|
||||
... vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max()) # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.asarray(data, dtype=dt)
|
||||
dt_cplx = np.result_type(dt, np.complex64)
|
||||
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
if np.isscalar(scales):
|
||||
scales = np.array([scales])
|
||||
if not np.isscalar(axis):
|
||||
raise np.AxisError("axis must be a scalar.")
|
||||
|
||||
dt_out = dt_cplx if wavelet.complex_cwt else dt
|
||||
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
|
||||
precision = 10
|
||||
int_psi, x = integrate_wavelet(wavelet, precision=precision)
|
||||
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
||||
|
||||
# convert int_psi, x to the same precision as the data
|
||||
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
|
||||
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
||||
x = np.asarray(x, dtype=data.real.dtype)
|
||||
|
||||
if method == 'fft':
|
||||
size_scale0 = -1
|
||||
fft_data = None
|
||||
elif not method == 'conv':
|
||||
raise ValueError("method must be 'conv' or 'fft'")
|
||||
|
||||
if data.ndim > 1:
|
||||
# move axis to be transformed last (so it is contiguous)
|
||||
data = data.swapaxes(-1, axis)
|
||||
|
||||
# reshape to (n_batch, data.shape[-1])
|
||||
data_shape_pre = data.shape
|
||||
data = data.reshape((-1, data.shape[-1]))
|
||||
|
||||
for i, scale in enumerate(scales):
|
||||
step = x[1] - x[0]
|
||||
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
|
||||
j = j.astype(int) # floor
|
||||
if j[-1] >= int_psi.size:
|
||||
j = np.extract(j < int_psi.size, j)
|
||||
int_psi_scale = int_psi[j][::-1]
|
||||
|
||||
if method == 'conv':
|
||||
if data.ndim == 1:
|
||||
conv = np.convolve(data, int_psi_scale)
|
||||
else:
|
||||
# batch convolution via loop
|
||||
conv_shape = list(data.shape)
|
||||
conv_shape[-1] += int_psi_scale.size - 1
|
||||
conv_shape = tuple(conv_shape)
|
||||
conv = np.empty(conv_shape, dtype=dt_out)
|
||||
for n in range(data.shape[0]):
|
||||
conv[n, :] = np.convolve(data[n], int_psi_scale)
|
||||
else:
|
||||
# The padding is selected for:
|
||||
# - optimal FFT complexity
|
||||
# - to be larger than the two signals length to avoid circular
|
||||
# convolution
|
||||
size_scale = next_fast_len(
|
||||
data.shape[-1] + int_psi_scale.size - 1
|
||||
)
|
||||
if size_scale != size_scale0:
|
||||
# Must recompute fft_data when the padding size changes.
|
||||
fft_data = fftmodule.fft(data, size_scale, axis=-1)
|
||||
size_scale0 = size_scale
|
||||
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
|
||||
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
|
||||
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
|
||||
|
||||
coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
|
||||
if out.dtype.kind != 'c':
|
||||
coef = coef.real
|
||||
# transform axis is always -1 due to the data reshape above
|
||||
d = (coef.shape[-1] - data.shape[-1]) / 2.
|
||||
if d > 0:
|
||||
coef = coef[..., floor(d):-ceil(d)]
|
||||
elif d < 0:
|
||||
raise ValueError(
|
||||
"Selected scale of {} too small.".format(scale))
|
||||
if data.ndim > 1:
|
||||
# restore original data shape and axis position
|
||||
coef = coef.reshape(data_shape_pre)
|
||||
coef = coef.swapaxes(axis, -1)
|
||||
out[i, ...] = coef
|
||||
|
||||
frequencies = scale2frequency(wavelet, scales, precision)
|
||||
if np.isscalar(frequencies):
|
||||
frequencies = np.array([frequencies])
|
||||
frequencies /= sampling_period
|
||||
return out, frequencies
|
||||
187
.CondaPkg/env/Lib/site-packages/pywt/_doc_utils.py
vendored
Normal file
187
.CondaPkg/env/Lib/site-packages/pywt/_doc_utils.py
vendored
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Utilities used to generate various figures in the documentation."""
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from ._dwt import pad
|
||||
|
||||
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
|
||||
'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']
|
||||
|
||||
|
||||
def wavedec_keys(level):
|
||||
"""Subband keys corresponding to a wavedec decomposition."""
|
||||
approx = ''
|
||||
coeffs = {}
|
||||
for lev in range(level):
|
||||
for k in ['a', 'd']:
|
||||
coeffs[approx + k] = None
|
||||
approx = 'a' * (lev + 1)
|
||||
if lev < level - 1:
|
||||
coeffs.pop(approx)
|
||||
return list(coeffs.keys())
|
||||
|
||||
|
||||
def wavedec2_keys(level):
|
||||
"""Subband keys corresponding to a wavedec2 decomposition."""
|
||||
approx = ''
|
||||
coeffs = {}
|
||||
for lev in range(level):
|
||||
for k in ['a', 'h', 'v', 'd']:
|
||||
coeffs[approx + k] = None
|
||||
approx = 'a' * (lev + 1)
|
||||
if lev < level - 1:
|
||||
coeffs.pop(approx)
|
||||
return list(coeffs.keys())
|
||||
|
||||
|
||||
def _box(bl, ur):
|
||||
"""(x, y) coordinates for the 4 lines making up a rectangular box.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
bl : float
|
||||
The bottom left corner of the box
|
||||
ur : float
|
||||
The upper right corner of the box
|
||||
|
||||
Returns
|
||||
=======
|
||||
coords : 2-tuple
|
||||
The first and second elements of the tuple are the x and y coordinates
|
||||
of the box.
|
||||
"""
|
||||
xl, xr = bl[0], ur[0]
|
||||
yb, yt = bl[1], ur[1]
|
||||
box_x = [xl, xr,
|
||||
xr, xr,
|
||||
xr, xl,
|
||||
xl, xl]
|
||||
box_y = [yb, yb,
|
||||
yb, yt,
|
||||
yt, yt,
|
||||
yt, yb]
|
||||
return (box_x, box_y)
|
||||
|
||||
|
||||
def _2d_wp_basis_coords(shape, keys):
|
||||
# Coordinates of the lines to be drawn by draw_2d_wp_basis
|
||||
coords = []
|
||||
centers = {} # retain center of boxes for use in labeling
|
||||
for key in keys:
|
||||
offset_x = offset_y = 0
|
||||
for n, char in enumerate(key):
|
||||
if char in ['h', 'd']:
|
||||
offset_x += shape[0] // 2**(n + 1)
|
||||
if char in ['v', 'd']:
|
||||
offset_y += shape[1] // 2**(n + 1)
|
||||
sx = shape[0] // 2**(n + 1)
|
||||
sy = shape[1] // 2**(n + 1)
|
||||
xc, yc = _box((offset_x, -offset_y),
|
||||
(offset_x + sx, -offset_y - sy))
|
||||
coords.append((xc, yc))
|
||||
centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
|
||||
return coords, centers
|
||||
|
||||
|
||||
def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs={}, ax=None,
|
||||
label_levels=0):
|
||||
"""Plot a 2D representation of a WaveletPacket2D basis."""
|
||||
coords, centers = _2d_wp_basis_coords(shape, keys)
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
for coord in coords:
|
||||
ax.plot(coord[0], coord[1], fmt)
|
||||
ax.set_axis_off()
|
||||
ax.axis('square')
|
||||
if label_levels > 0:
|
||||
for key, c in centers.items():
|
||||
if len(key) <= label_levels:
|
||||
ax.text(c[0], c[1], key,
|
||||
horizontalalignment='center',
|
||||
verticalalignment='center')
|
||||
return fig, ax
|
||||
|
||||
|
||||
def _2d_fswavedecn_coords(shape, levels):
|
||||
coords = []
|
||||
centers = {} # retain center of boxes for use in labeling
|
||||
for key in product(wavedec_keys(levels), repeat=2):
|
||||
(key0, key1) = key
|
||||
offsets = [0, 0]
|
||||
widths = list(shape)
|
||||
for n0, char in enumerate(key0):
|
||||
if char in ['d']:
|
||||
offsets[0] += shape[0] // 2**(n0 + 1)
|
||||
for n1, char in enumerate(key1):
|
||||
if char in ['d']:
|
||||
offsets[1] += shape[1] // 2**(n1 + 1)
|
||||
widths[0] = shape[0] // 2**(n0 + 1)
|
||||
widths[1] = shape[1] // 2**(n1 + 1)
|
||||
xc, yc = _box((offsets[0], -offsets[1]),
|
||||
(offsets[0] + widths[0], -offsets[1] - widths[1]))
|
||||
coords.append((xc, yc))
|
||||
centers[(key0, key1)] = (offsets[0] + widths[0] / 2,
|
||||
-offsets[1] - widths[1] / 2)
|
||||
return coords, centers
|
||||
|
||||
|
||||
def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs={}, ax=None,
|
||||
label_levels=0):
|
||||
"""Plot a 2D representation of a WaveletPacket2D basis."""
|
||||
coords, centers = _2d_fswavedecn_coords(shape, levels)
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
for coord in coords:
|
||||
ax.plot(coord[0], coord[1], fmt)
|
||||
ax.set_axis_off()
|
||||
ax.axis('square')
|
||||
if label_levels > 0:
|
||||
for key, c in centers.items():
|
||||
lev = np.max([len(k) for k in key])
|
||||
if lev <= label_levels:
|
||||
ax.text(c[0], c[1], key,
|
||||
horizontalalignment='center',
|
||||
verticalalignment='center')
|
||||
return fig, ax
|
||||
|
||||
|
||||
def boundary_mode_subplot(x, mode, ax, symw=True):
|
||||
"""Plot an illustration of the boundary mode in a subplot axis."""
|
||||
|
||||
# if odd-length, periodization replicates the last sample to make it even
|
||||
if mode == 'periodization' and len(x) % 2 == 1:
|
||||
x = np.concatenate((x, (x[-1], )))
|
||||
|
||||
npad = 2 * len(x)
|
||||
t = np.arange(len(x) + 2 * npad)
|
||||
xp = pad(x, (npad, npad), mode=mode)
|
||||
|
||||
ax.plot(t, xp, 'k.')
|
||||
ax.set_title(mode)
|
||||
|
||||
# plot the original signal in red
|
||||
if mode == 'periodization':
|
||||
ax.plot(t[npad:npad + len(x) - 1], x[:-1], 'r.')
|
||||
else:
|
||||
ax.plot(t[npad:npad + len(x)], x, 'r.')
|
||||
|
||||
# add vertical bars indicating points of symmetry or boundary extension
|
||||
o2 = np.ones(2)
|
||||
left = npad
|
||||
if symw:
|
||||
step = len(x) - 1
|
||||
rng = range(-2, 4)
|
||||
else:
|
||||
left -= 0.5
|
||||
step = len(x)
|
||||
rng = range(-2, 4)
|
||||
if mode in ['smooth', 'constant', 'zero']:
|
||||
rng = range(0, 2)
|
||||
for rep in rng:
|
||||
ax.plot((left + rep * step) * o2, [xp.min() - .5, xp.max() + .5], 'k-')
|
||||
517
.CondaPkg/env/Lib/site-packages/pywt/_dwt.py
vendored
Normal file
517
.CondaPkg/env/Lib/site-packages/pywt/_dwt.py
vendored
Normal file
@@ -0,0 +1,517 @@
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._c99_config import _have_c99_complex
|
||||
from ._extensions._pywt import Wavelet, Modes, _check_dtype, wavelist
|
||||
from ._extensions._dwt import (dwt_single, dwt_axis, idwt_single, idwt_axis,
|
||||
upcoef as _upcoef, downcoef as _downcoef,
|
||||
dwt_max_level as _dwt_max_level,
|
||||
dwt_coeff_len as _dwt_coeff_len)
|
||||
from ._utils import _as_wavelet
|
||||
|
||||
|
||||
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
|
||||
"dwt_coeff_len", "pad"]
|
||||
|
||||
|
||||
def dwt_max_level(data_len, filter_len):
|
||||
r"""
|
||||
dwt_max_level(data_len, filter_len)
|
||||
|
||||
Compute the maximum useful level of decomposition.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_len : int
|
||||
Input data length.
|
||||
filter_len : int, str or Wavelet
|
||||
The wavelet filter length. Alternatively, the name of a discrete
|
||||
wavelet or a Wavelet object can be specified.
|
||||
|
||||
Returns
|
||||
-------
|
||||
max_level : int
|
||||
Maximum level.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The rational for the choice of levels is the maximum level where at least
|
||||
one coefficient in the output is uncorrupted by edge effects caused by
|
||||
signal extension. Put another way, decomposition stops when the signal
|
||||
becomes shorter than the FIR filter length for a given wavelet. This
|
||||
corresponds to:
|
||||
|
||||
.. max_level = floor(log2(data_len/(filter_len - 1)))
|
||||
|
||||
.. math::
|
||||
\mathtt{max\_level} = \left\lfloor\log_2\left(\mathtt{
|
||||
\frac{data\_len}{filter\_len - 1}}\right)\right\rfloor
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> w = pywt.Wavelet('sym5')
|
||||
>>> pywt.dwt_max_level(data_len=1000, filter_len=w.dec_len)
|
||||
6
|
||||
>>> pywt.dwt_max_level(1000, w)
|
||||
6
|
||||
>>> pywt.dwt_max_level(1000, 'sym5')
|
||||
6
|
||||
"""
|
||||
if isinstance(filter_len, Wavelet):
|
||||
filter_len = filter_len.dec_len
|
||||
elif isinstance(filter_len, str):
|
||||
if filter_len in wavelist(kind='discrete'):
|
||||
filter_len = Wavelet(filter_len).dec_len
|
||||
else:
|
||||
raise ValueError(
|
||||
("'{}', is not a recognized discrete wavelet. A list of "
|
||||
"supported wavelet names can be obtained via "
|
||||
"pywt.wavelist(kind='discrete')").format(filter_len))
|
||||
elif not (isinstance(filter_len, Number) and filter_len % 1 == 0):
|
||||
raise ValueError(
|
||||
"filter_len must be an integer, discrete Wavelet object, or the "
|
||||
"name of a discrete wavelet.")
|
||||
|
||||
if filter_len < 2:
|
||||
raise ValueError("invalid wavelet filter length")
|
||||
|
||||
return _dwt_max_level(data_len, filter_len)
|
||||
|
||||
|
||||
def dwt_coeff_len(data_len, filter_len, mode):
|
||||
"""
|
||||
dwt_coeff_len(data_len, filter_len, mode='symmetric')
|
||||
|
||||
Returns length of dwt output for given data length, filter length and mode
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_len : int
|
||||
Data length.
|
||||
filter_len : int
|
||||
Filter length.
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
len : int
|
||||
Length of dwt output.
|
||||
|
||||
Notes
|
||||
-----
|
||||
For all modes except periodization::
|
||||
|
||||
len(cA) == len(cD) == floor((len(data) + wavelet.dec_len - 1) / 2)
|
||||
|
||||
for periodization mode ("per")::
|
||||
|
||||
len(cA) == len(cD) == ceil(len(data) / 2)
|
||||
|
||||
"""
|
||||
if isinstance(filter_len, Wavelet):
|
||||
filter_len = filter_len.dec_len
|
||||
|
||||
return _dwt_coeff_len(data_len, filter_len, Modes.from_object(mode))
|
||||
|
||||
|
||||
def dwt(data, wavelet, mode='symmetric', axis=-1):
|
||||
"""
|
||||
dwt(data, wavelet, mode='symmetric', axis=-1)
|
||||
|
||||
Single level Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
Input signal
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
axis: int, optional
|
||||
Axis over which to compute the DWT. If not given, the
|
||||
last axis is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(cA, cD) : tuple
|
||||
Approximation and detail coefficients.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Length of coefficients arrays depends on the selected mode.
|
||||
For all modes except periodization:
|
||||
|
||||
``len(cA) == len(cD) == floor((len(data) + wavelet.dec_len - 1) / 2)``
|
||||
|
||||
For periodization mode ("per"):
|
||||
|
||||
``len(cA) == len(cD) == ceil(len(data) / 2)``
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> (cA, cD) = pywt.dwt([1, 2, 3, 4, 5, 6], 'db1')
|
||||
>>> cA
|
||||
array([ 2.12132034, 4.94974747, 7.77817459])
|
||||
>>> cD
|
||||
array([-0.70710678, -0.70710678, -0.70710678])
|
||||
|
||||
"""
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
data = np.asarray(data)
|
||||
cA_r, cD_r = dwt(data.real, wavelet, mode, axis)
|
||||
cA_i, cD_i = dwt(data.imag, wavelet, mode, axis)
|
||||
return (cA_r + 1j*cA_i, cD_r + 1j*cD_i)
|
||||
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.asarray(data, dtype=dt, order='C')
|
||||
mode = Modes.from_object(mode)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
|
||||
if axis < 0:
|
||||
axis = axis + data.ndim
|
||||
if not 0 <= axis < data.ndim:
|
||||
raise np.AxisError("Axis greater than data dimensions")
|
||||
|
||||
if data.ndim == 1:
|
||||
cA, cD = dwt_single(data, wavelet, mode)
|
||||
# TODO: Check whether this makes a copy
|
||||
cA, cD = np.asarray(cA, dt), np.asarray(cD, dt)
|
||||
else:
|
||||
cA, cD = dwt_axis(data, wavelet, mode, axis=axis)
|
||||
|
||||
return (cA, cD)
|
||||
|
||||
|
||||
def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
|
||||
"""
|
||||
idwt(cA, cD, wavelet, mode='symmetric', axis=-1)
|
||||
|
||||
Single level Inverse Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cA : array_like or None
|
||||
Approximation coefficients. If None, will be set to array of zeros
|
||||
with same shape as ``cD``.
|
||||
cD : array_like or None
|
||||
Detail coefficients. If None, will be set to array of zeros
|
||||
with same shape as ``cA``.
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
mode : str, optional (default: 'symmetric')
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
axis: int, optional
|
||||
Axis over which to compute the inverse DWT. If not given, the
|
||||
last axis is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec: array_like
|
||||
Single level reconstruction of signal from given coefficients.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> (cA, cD) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
|
||||
>>> pywt.idwt(cA, cD, 'db2', 'smooth')
|
||||
array([ 1., 2., 3., 4., 5., 6.])
|
||||
|
||||
One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD``
|
||||
arguments can be set to None. In that situation the reconstruction will be
|
||||
performed using only the other one. Mathematically speaking, this is
|
||||
equivalent to passing a zero-filled array as one of the arguments.
|
||||
|
||||
>>> (cA, cD) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
|
||||
>>> A = pywt.idwt(cA, None, 'db2', 'smooth')
|
||||
>>> D = pywt.idwt(None, cD, 'db2', 'smooth')
|
||||
>>> A + D
|
||||
array([ 1., 2., 3., 4., 5., 6.])
|
||||
|
||||
"""
|
||||
# TODO: Lots of possible allocations to eliminate (zeros_like, asarray(rec))
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
|
||||
if cA is None and cD is None:
|
||||
raise ValueError("At least one coefficient parameter must be "
|
||||
"specified.")
|
||||
|
||||
# for complex inputs: compute real and imaginary separately then combine
|
||||
if not _have_c99_complex and (np.iscomplexobj(cA) or np.iscomplexobj(cD)):
|
||||
if cA is None:
|
||||
cD = np.asarray(cD)
|
||||
cA = np.zeros_like(cD)
|
||||
elif cD is None:
|
||||
cA = np.asarray(cA)
|
||||
cD = np.zeros_like(cA)
|
||||
return (idwt(cA.real, cD.real, wavelet, mode, axis) +
|
||||
1j*idwt(cA.imag, cD.imag, wavelet, mode, axis))
|
||||
|
||||
if cA is not None:
|
||||
dt = _check_dtype(cA)
|
||||
cA = np.asarray(cA, dtype=dt, order='C')
|
||||
if cD is not None:
|
||||
dt = _check_dtype(cD)
|
||||
cD = np.asarray(cD, dtype=dt, order='C')
|
||||
|
||||
if cA is not None and cD is not None:
|
||||
if cA.dtype != cD.dtype:
|
||||
# need to upcast to common type
|
||||
if cA.dtype.kind == 'c' or cD.dtype.kind == 'c':
|
||||
dtype = np.complex128
|
||||
else:
|
||||
dtype = np.float64
|
||||
cA = cA.astype(dtype)
|
||||
cD = cD.astype(dtype)
|
||||
elif cA is None:
|
||||
cA = np.zeros_like(cD)
|
||||
elif cD is None:
|
||||
cD = np.zeros_like(cA)
|
||||
|
||||
# cA and cD should be same dimension by here
|
||||
ndim = cA.ndim
|
||||
|
||||
mode = Modes.from_object(mode)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
|
||||
if axis < 0:
|
||||
axis = axis + ndim
|
||||
if not 0 <= axis < ndim:
|
||||
raise np.AxisError("Axis greater than coefficient dimensions")
|
||||
|
||||
if ndim == 1:
|
||||
rec = idwt_single(cA, cD, wavelet, mode)
|
||||
else:
|
||||
rec = idwt_axis(cA, cD, wavelet, mode, axis=axis)
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
def downcoef(part, data, wavelet, mode='symmetric', level=1):
|
||||
"""
|
||||
downcoef(part, data, wavelet, mode='symmetric', level=1)
|
||||
|
||||
Partial Discrete Wavelet Transform data decomposition.
|
||||
|
||||
Similar to ``pywt.dwt``, but computes only one set of coefficients.
|
||||
Useful when you need only approximation or only details at the given level.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
part : str
|
||||
Coefficients type:
|
||||
|
||||
* 'a' - approximations reconstruction is performed
|
||||
* 'd' - details reconstruction is performed
|
||||
|
||||
data : array_like
|
||||
Input signal.
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
level : int, optional
|
||||
Decomposition level. Default is 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : ndarray
|
||||
1-D array of coefficients.
|
||||
|
||||
See Also
|
||||
--------
|
||||
upcoef
|
||||
|
||||
"""
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
return (downcoef(part, data.real, wavelet, mode, level) +
|
||||
1j*downcoef(part, data.imag, wavelet, mode, level))
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.asarray(data, dtype=dt, order='C')
|
||||
if data.ndim > 1:
|
||||
raise ValueError("downcoef only supports 1d data.")
|
||||
if part not in 'ad':
|
||||
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
|
||||
mode = Modes.from_object(mode)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
return np.asarray(_downcoef(part == 'a', data, wavelet, mode, level))
|
||||
|
||||
|
||||
def upcoef(part, coeffs, wavelet, level=1, take=0):
|
||||
"""
|
||||
upcoef(part, coeffs, wavelet, level=1, take=0)
|
||||
|
||||
Direct reconstruction from coefficients.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
part : str
|
||||
Coefficients type:
|
||||
* 'a' - approximations reconstruction is performed
|
||||
* 'd' - details reconstruction is performed
|
||||
coeffs : array_like
|
||||
Coefficients array to reconstruct
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
level : int, optional
|
||||
Multilevel reconstruction level. Default is 1.
|
||||
take : int, optional
|
||||
Take central part of length equal to 'take' from the result.
|
||||
Default is 0.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec : ndarray
|
||||
1-D array with reconstructed data from coefficients.
|
||||
|
||||
See Also
|
||||
--------
|
||||
downcoef
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> data = [1,2,3,4,5,6]
|
||||
>>> (cA, cD) = pywt.dwt(data, 'db2', 'smooth')
|
||||
>>> pywt.upcoef('a', cA, 'db2') + pywt.upcoef('d', cD, 'db2')
|
||||
array([-0.25 , -0.4330127 , 1. , 2. , 3. ,
|
||||
4. , 5. , 6. , 1.78589838, -1.03108891])
|
||||
>>> n = len(data)
|
||||
>>> pywt.upcoef('a', cA, 'db2', take=n) + pywt.upcoef('d', cD, 'db2', take=n)
|
||||
array([ 1., 2., 3., 4., 5., 6.])
|
||||
|
||||
"""
|
||||
if not _have_c99_complex and np.iscomplexobj(coeffs):
|
||||
return (upcoef(part, coeffs.real, wavelet, level, take) +
|
||||
1j*upcoef(part, coeffs.imag, wavelet, level, take))
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(coeffs)
|
||||
coeffs = np.asarray(coeffs, dtype=dt, order='C')
|
||||
if coeffs.ndim > 1:
|
||||
raise ValueError("upcoef only supports 1d coeffs.")
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
if part not in 'ad':
|
||||
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
|
||||
return np.asarray(_upcoef(part == 'a', coeffs, wavelet, level, take))
|
||||
|
||||
|
||||
def pad(x, pad_widths, mode):
|
||||
"""Extend a 1D signal using a given boundary mode.
|
||||
|
||||
This function operates like :func:`numpy.pad` but supports all signal
|
||||
extension modes that can be used by PyWavelets discrete wavelet transforms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : ndarray
|
||||
The array to pad
|
||||
pad_widths : {sequence, array_like, int}
|
||||
Number of values padded to the edges of each axis.
|
||||
``((before_1, after_1), … (before_N, after_N))`` unique pad widths for
|
||||
each axis. ``((before, after),)`` yields same before and after pad for
|
||||
each axis. ``(pad,)`` or int is a shortcut for
|
||||
``before = after = pad width`` for all axes.
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pad : ndarray
|
||||
Padded array of rank equal to array with shape increased according to
|
||||
``pad_widths``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The performance of padding in dimensions > 1 may be substantially slower
|
||||
for modes ``'smooth'`` and ``'antisymmetric'`` as these modes are not
|
||||
supported efficiently by the underlying :func:`numpy.pad` function.
|
||||
|
||||
Note that the behavior of the ``'constant'`` mode here follows the
|
||||
PyWavelets convention which is different from NumPy (it is equivalent to
|
||||
``mode='edge'`` in :func:`numpy.pad`).
|
||||
"""
|
||||
x = np.asanyarray(x)
|
||||
|
||||
# process pad_widths exactly as in numpy.pad
|
||||
pad_widths = np.array(pad_widths)
|
||||
pad_widths = np.round(pad_widths).astype(np.intp, copy=False)
|
||||
if pad_widths.min() < 0:
|
||||
raise ValueError("pad_widths must be > 0")
|
||||
pad_widths = np.broadcast_to(pad_widths, (x.ndim, 2)).tolist()
|
||||
|
||||
if mode in ['symmetric', 'reflect']:
|
||||
xp = np.pad(x, pad_widths, mode=mode)
|
||||
elif mode in ['periodic', 'periodization']:
|
||||
if mode == 'periodization':
|
||||
# Promote odd-sized dimensions to even length by duplicating the
|
||||
# last value.
|
||||
edge_pad_widths = [(0, x.shape[ax] % 2)
|
||||
for ax in range(x.ndim)]
|
||||
x = np.pad(x, edge_pad_widths, mode='edge')
|
||||
xp = np.pad(x, pad_widths, mode='wrap')
|
||||
elif mode == 'zero':
|
||||
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
|
||||
elif mode == 'constant':
|
||||
xp = np.pad(x, pad_widths, mode='edge')
|
||||
elif mode == 'smooth':
|
||||
def pad_smooth(vector, pad_width, iaxis, kwargs):
|
||||
# smooth extension to left
|
||||
left = vector[pad_width[0]]
|
||||
slope_left = (left - vector[pad_width[0] + 1])
|
||||
vector[:pad_width[0]] = \
|
||||
left + np.arange(pad_width[0], 0, -1) * slope_left
|
||||
|
||||
# smooth extension to right
|
||||
right = vector[-pad_width[1] - 1]
|
||||
slope_right = (right - vector[-pad_width[1] - 2])
|
||||
vector[-pad_width[1]:] = \
|
||||
right + np.arange(1, pad_width[1] + 1) * slope_right
|
||||
return vector
|
||||
xp = np.pad(x, pad_widths, pad_smooth)
|
||||
elif mode == 'antisymmetric':
|
||||
def pad_antisymmetric(vector, pad_width, iaxis, kwargs):
|
||||
# smooth extension to left
|
||||
# implement by flipping portions symmetric padding
|
||||
npad_l, npad_r = pad_width
|
||||
vsize_nonpad = vector.size - npad_l - npad_r
|
||||
# Note: must modify vector in-place
|
||||
vector[:] = np.pad(vector[pad_width[0]:-pad_width[-1]],
|
||||
pad_width, mode='symmetric')
|
||||
vp = vector
|
||||
r_edge = npad_l + vsize_nonpad - 1
|
||||
l_edge = npad_l
|
||||
# width of each reflected segment
|
||||
seg_width = vsize_nonpad
|
||||
# flip reflected segments on the right of the original signal
|
||||
n = 1
|
||||
while r_edge <= vp.size:
|
||||
segment_slice = slice(r_edge + 1,
|
||||
min(r_edge + 1 + seg_width, vp.size))
|
||||
if n % 2:
|
||||
vp[segment_slice] *= -1
|
||||
r_edge += seg_width
|
||||
n += 1
|
||||
|
||||
# flip reflected segments on the left of the original signal
|
||||
n = 1
|
||||
while l_edge >= 0:
|
||||
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
|
||||
if n % 2:
|
||||
vp[segment_slice] *= -1
|
||||
l_edge -= seg_width
|
||||
n += 1
|
||||
return vector
|
||||
xp = np.pad(x, pad_widths, pad_antisymmetric)
|
||||
elif mode == 'antireflect':
|
||||
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
|
||||
else:
|
||||
raise ValueError(
|
||||
("unsupported mode: {}. The supported modes are {}").format(
|
||||
mode, Modes.modes))
|
||||
return xp
|
||||
0
.CondaPkg/env/Lib/site-packages/pywt/_extensions/__init__.py
vendored
Normal file
0
.CondaPkg/env/Lib/site-packages/pywt/_extensions/__init__.py
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_cwt.cp311-win_amd64.pyd
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_cwt.cp311-win_amd64.pyd
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_dwt.cp311-win_amd64.pyd
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_dwt.cp311-win_amd64.pyd
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_pywt.cp311-win_amd64.pyd
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_pywt.cp311-win_amd64.pyd
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_swt.cp311-win_amd64.pyd
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/_extensions/_swt.cp311-win_amd64.pyd
vendored
Normal file
Binary file not shown.
265
.CondaPkg/env/Lib/site-packages/pywt/_functions.py
vendored
Normal file
265
.CondaPkg/env/Lib/site-packages/pywt/_functions.py
vendored
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
Other wavelet related functions.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
|
||||
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet
|
||||
|
||||
|
||||
__all__ = ["integrate_wavelet", "central_frequency",
|
||||
"scale2frequency", "frequency2scale", "qmf",
|
||||
"orthogonal_filter_bank",
|
||||
"intwave", "centrfrq", "scal2frq", "orthfilt"]
|
||||
|
||||
|
||||
_DEPRECATION_MSG = ("`{old}` has been renamed to `{new}` and will "
|
||||
"be removed in a future version of pywt.")
|
||||
|
||||
|
||||
def _integrate(arr, step):
|
||||
integral = np.cumsum(arr)
|
||||
integral *= step
|
||||
return integral
|
||||
|
||||
|
||||
def intwave(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='intwave', new='integrate_wavelet')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return integrate_wavelet(*args, **kwargs)
|
||||
|
||||
|
||||
def centrfrq(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='centrfrq', new='central_frequency')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return central_frequency(*args, **kwargs)
|
||||
|
||||
|
||||
def scal2frq(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='scal2frq', new='scale2frequency')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return scale2frequency(*args, **kwargs)
|
||||
|
||||
|
||||
def orthfilt(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='orthfilt', new='orthogonal_filter_bank')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return orthogonal_filter_bank(*args, **kwargs)
|
||||
|
||||
|
||||
def integrate_wavelet(wavelet, precision=8):
|
||||
"""
|
||||
Integrate `psi` wavelet function from -Inf to x using the rectangle
|
||||
integration method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance or str
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function
|
||||
approximation computed with the wavefun(level=precision)
|
||||
Wavelet's method (default: 8).
|
||||
|
||||
Returns
|
||||
-------
|
||||
[int_psi, x] :
|
||||
for orthogonal wavelets
|
||||
[int_psi_d, int_psi_r, x] :
|
||||
for other wavelets
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from pywt import Wavelet, integrate_wavelet
|
||||
>>> wavelet1 = Wavelet('db2')
|
||||
>>> [int_psi, x] = integrate_wavelet(wavelet1, precision=5)
|
||||
>>> wavelet2 = Wavelet('bior1.3')
|
||||
>>> [int_psi_d, int_psi_r, x] = integrate_wavelet(wavelet2, precision=5)
|
||||
|
||||
"""
|
||||
# FIXME: this function should really use scipy.integrate.quad
|
||||
|
||||
if type(wavelet) in (tuple, list):
|
||||
msg = ("Integration of a general signal is deprecated "
|
||||
"and will be removed in a future version of pywt.")
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
|
||||
if type(wavelet) in (tuple, list):
|
||||
psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi, step), x
|
||||
|
||||
functions_approximations = wavelet.wavefun(precision)
|
||||
|
||||
if len(functions_approximations) == 2: # continuous wavelet
|
||||
psi, x = functions_approximations
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi, step), x
|
||||
|
||||
elif len(functions_approximations) == 3: # orthogonal wavelet
|
||||
phi, psi, x = functions_approximations
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi, step), x
|
||||
|
||||
else: # biorthogonal wavelet
|
||||
phi_d, psi_d, phi_r, psi_r, x = functions_approximations
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi_d, step), _integrate(psi_r, step), x
|
||||
|
||||
|
||||
def central_frequency(wavelet, precision=8):
|
||||
"""
|
||||
Computes the central frequency of the `psi` wavelet function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance, str or tuple
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function
|
||||
approximation computed with the wavefun(level=precision)
|
||||
Wavelet's method (default: 8).
|
||||
|
||||
Returns
|
||||
-------
|
||||
scalar
|
||||
|
||||
"""
|
||||
|
||||
if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
|
||||
functions_approximations = wavelet.wavefun(precision)
|
||||
|
||||
if len(functions_approximations) == 2:
|
||||
psi, x = functions_approximations
|
||||
else:
|
||||
# (psi, x) for (phi, psi, x)
|
||||
# (psi_d, x) for (phi_d, psi_d, phi_r, psi_r, x)
|
||||
psi, x = functions_approximations[1], functions_approximations[-1]
|
||||
|
||||
domain = float(x[-1] - x[0])
|
||||
assert domain > 0
|
||||
|
||||
index = np.argmax(abs(fft(psi)[1:])) + 2
|
||||
if index > len(psi) / 2:
|
||||
index = len(psi) - index + 2
|
||||
|
||||
return 1.0 / (domain / (index - 1))
|
||||
|
||||
|
||||
def scale2frequency(wavelet, scale, precision=8):
|
||||
"""Convert from CWT "scale" to normalized frequency.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance or str
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
scale : scalar
|
||||
The scale of the CWT.
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function approximation computed
|
||||
with ``wavelet.wavefun(level=precision)``. Default is 8.
|
||||
|
||||
Returns
|
||||
-------
|
||||
freq : scalar
|
||||
Frequency normalized to the sampling frequency. In other words, for a
|
||||
sampling interval of `dt` seconds, the normalized frequency of 1.0
|
||||
corresponds to (`1/dt` Hz).
|
||||
|
||||
"""
|
||||
return central_frequency(wavelet, precision=precision) / scale
|
||||
|
||||
def frequency2scale(wavelet, freq, precision=8):
|
||||
"""Convert from to normalized frequency to CWT "scale".
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance or str
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
freq : scalar
|
||||
Frequency, normalized so that the sampling frequency corresponds to a
|
||||
value of 1.0.
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function approximation computed
|
||||
with ``wavelet.wavefun(level=precision)``. Default is 8.
|
||||
|
||||
Returns
|
||||
-------
|
||||
scale : scalar
|
||||
|
||||
"""
|
||||
return central_frequency(wavelet, precision=precision) / freq
|
||||
|
||||
def qmf(filt):
|
||||
"""
|
||||
Returns the Quadrature Mirror Filter(QMF).
|
||||
|
||||
The magnitude response of QMF is mirror image about `pi/2` of that of the
|
||||
input filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filt : array_like
|
||||
Input filter for which QMF needs to be computed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
qm_filter : ndarray
|
||||
Quadrature mirror of the input filter.
|
||||
|
||||
"""
|
||||
qm_filter = np.array(filt)[::-1]
|
||||
qm_filter[1::2] = -qm_filter[1::2]
|
||||
return qm_filter
|
||||
|
||||
|
||||
def orthogonal_filter_bank(scaling_filter):
|
||||
"""
|
||||
Returns the orthogonal filter bank.
|
||||
|
||||
The orthogonal filter bank consists of the HPFs and LPFs at
|
||||
decomposition and reconstruction stage for the input scaling filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scaling_filter : array_like
|
||||
Input scaling filter (father wavelet).
|
||||
|
||||
Returns
|
||||
-------
|
||||
orth_filt_bank : tuple of 4 ndarrays
|
||||
The orthogonal filter bank of the input scaling filter in the order :
|
||||
1] Decomposition LPF
|
||||
2] Decomposition HPF
|
||||
3] Reconstruction LPF
|
||||
4] Reconstruction HPF
|
||||
|
||||
"""
|
||||
if not (len(scaling_filter) % 2 == 0):
|
||||
raise ValueError("`scaling_filter` length has to be even.")
|
||||
|
||||
scaling_filter = np.asarray(scaling_filter, dtype=np.float64)
|
||||
|
||||
rec_lo = np.sqrt(2) * scaling_filter / np.sum(scaling_filter)
|
||||
dec_lo = rec_lo[::-1]
|
||||
|
||||
rec_hi = qmf(rec_lo)
|
||||
dec_hi = rec_hi[::-1]
|
||||
|
||||
orth_filt_bank = (dec_lo, dec_hi, rec_lo, rec_hi)
|
||||
return orth_filt_bank
|
||||
427
.CondaPkg/env/Lib/site-packages/pywt/_mra.py
vendored
Normal file
427
.CondaPkg/env/Lib/site-packages/pywt/_mra.py
vendored
Normal file
@@ -0,0 +1,427 @@
|
||||
from functools import partial, reduce
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._multilevel import (_prep_axes_wavedecn, wavedec, wavedec2, wavedecn,
|
||||
waverec, waverec2, waverecn)
|
||||
from ._swt import iswt, iswt2, iswtn, swt, swt2, swt_max_level, swtn
|
||||
from ._utils import _modes_per_axis, _wavelets_per_axis
|
||||
|
||||
__all__ = ["mra", "mra2", "mran", "imra", "imra2", "imran"]
|
||||
|
||||
|
||||
def mra(data, wavelet, level=None, axis=-1, transform='swt',
|
||||
mode='periodization'):
|
||||
"""Forward 1D multiresolution analysis.
|
||||
|
||||
It is a projection onto the wavelet subspaces.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: array_like
|
||||
Input data
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet to use
|
||||
level : int, optional
|
||||
Decomposition level (must be >= 0). If level is None (default) then it
|
||||
will be calculated using the `dwt_max_level` function.
|
||||
axis: int, optional
|
||||
Axis over which to compute the DWT. If not given, the last axis is
|
||||
used. Currently only available when ``transform='dwt'``.
|
||||
transform : {'dwt', 'swt'}
|
||||
Whether to use the DWT or SWT for the transforms.
|
||||
mode : str, optional
|
||||
Signal extension mode, see `Modes` (default: 'symmetric'). This option
|
||||
is only used when transform='dwt'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
[cAn, {details_level_n}, ... {details_level_1}] : list
|
||||
For more information, see the detailed description in `wavedec`
|
||||
|
||||
See Also
|
||||
--------
|
||||
imra, swt
|
||||
|
||||
Notes
|
||||
-----
|
||||
This is sometimes referred to as an additive decomposition because the
|
||||
inverse transform (``imra``) is just the sum of the coefficient arrays
|
||||
[1]_. The decomposition using ``transform='dwt'`` corresponds to section
|
||||
2.2 while that using an undecimated transform (``transform='swt'``) is
|
||||
described in section 3.2 and appendix A.
|
||||
|
||||
This transform does not share the variance partition property of ``swt``
|
||||
with `norm=True`. It does however, result in coefficients that are
|
||||
temporally aligned regardless of the symmetry of the wavelet used.
|
||||
|
||||
The redundancy of this transform is ``(level + 1)``.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
|
||||
Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
|
||||
Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
|
||||
https://doi.org/10.2307/2965551
|
||||
|
||||
"""
|
||||
if transform == 'swt':
|
||||
if mode != 'periodization':
|
||||
raise ValueError(
|
||||
"transform swt only supports mode='periodization'")
|
||||
kwargs = dict(wavelet=wavelet, axis=axis, norm=True)
|
||||
forward = partial(swt, level=level, trim_approx=True, **kwargs)
|
||||
inverse = partial(iswt, **kwargs)
|
||||
is_swt = True
|
||||
elif transform == 'dwt':
|
||||
kwargs = dict(wavelet=wavelet, mode=mode, axis=axis)
|
||||
forward = partial(wavedec, level=level, **kwargs)
|
||||
inverse = partial(waverec, **kwargs)
|
||||
is_swt = False
|
||||
else:
|
||||
raise ValueError("unrecognized transform: {}".format(transform))
|
||||
|
||||
wav_coeffs = forward(data)
|
||||
|
||||
mra_coeffs = []
|
||||
nc = len(wav_coeffs)
|
||||
|
||||
if is_swt:
|
||||
# replicate same zeros array to save memory
|
||||
z = np.zeros_like(wav_coeffs[0])
|
||||
tmp = [z, ] * nc
|
||||
else:
|
||||
# zero arrays have variable size in DWT case
|
||||
tmp = [np.zeros_like(c) for c in wav_coeffs]
|
||||
|
||||
for j in range(nc):
|
||||
# tmp has arrays of zeros except for the jth entry
|
||||
tmp[j] = wav_coeffs[j]
|
||||
|
||||
# reconstruct
|
||||
rec = inverse(tmp)
|
||||
if rec.shape != data.shape:
|
||||
# trim any excess coefficients
|
||||
rec = rec[tuple([slice(sz) for sz in data.shape])]
|
||||
mra_coeffs.append(rec)
|
||||
|
||||
# restore zeros
|
||||
if is_swt:
|
||||
tmp[j] = z
|
||||
else:
|
||||
tmp[j] = np.zeros_like(tmp[j])
|
||||
return mra_coeffs
|
||||
|
||||
|
||||
def imra(mra_coeffs):
|
||||
"""Inverse 1D multiresolution analysis via summation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mra_coeffs : list of ndarray
|
||||
Multiresolution analysis coefficients as returned by `mra`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec : ndarray
|
||||
The reconstructed signal.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mra
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
|
||||
Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
|
||||
Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
|
||||
https://doi.org/10.2307/2965551
|
||||
"""
|
||||
return reduce(lambda x, y: x + y, mra_coeffs)
|
||||
|
||||
|
||||
def mra2(data, wavelet, level=None, axes=(-2, -1), transform='swt2',
|
||||
mode='periodization'):
|
||||
"""Forward 2D multiresolution analysis.
|
||||
|
||||
It is a projection onto wavelet subspaces.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: array_like
|
||||
Input data
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in `axes`.
|
||||
level : int, optional
|
||||
Decomposition level (must be >= 0). If level is None (default) then it
|
||||
will be calculated using the `dwt_max_level` function.
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the DWT. Repeated elements are not allowed.
|
||||
Currently only available when ``transform='dwt2'``.
|
||||
transform : {'dwt2', 'swt2'}
|
||||
Whether to use the DWT or SWT for the transforms.
|
||||
mode : str or 2-tuple of str, optional
|
||||
Signal extension mode, see `Modes` (default: 'symmetric'). This option
|
||||
is only used when transform='dwt2'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : list
|
||||
For more information, see the detailed description in `wavedec2`
|
||||
|
||||
Notes
|
||||
-----
|
||||
This is sometimes referred to as an additive decomposition because the
|
||||
inverse transform (``imra2``) is just the sum of the coefficient arrays
|
||||
[1]_. The decomposition using ``transform='dwt'`` corresponds to section
|
||||
2.2 while that using an undecimated transform (``transform='swt'``) is
|
||||
described in section 3.2 and appendix A.
|
||||
|
||||
This transform does not share the variance partition property of ``swt2``
|
||||
with `norm=True`. It does however, result in coefficients that are
|
||||
temporally aligned regardless of the symmetry of the wavelet used.
|
||||
|
||||
The redundancy of this transform is ``3 * level + 1``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
imra2, swt2
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
|
||||
Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
|
||||
Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
|
||||
https://doi.org/10.2307/2965551
|
||||
"""
|
||||
if transform == 'swt2':
|
||||
if mode != 'periodization':
|
||||
raise ValueError(
|
||||
"transform swt only supports mode='periodization'")
|
||||
if level is None:
|
||||
level = min(swt_max_level(s) for s in data.shape)
|
||||
kwargs = dict(wavelet=wavelet, axes=axes, norm=True)
|
||||
forward = partial(swt2, level=level, trim_approx=True, **kwargs)
|
||||
inverse = partial(iswt2, **kwargs)
|
||||
elif transform == 'dwt2':
|
||||
kwargs = dict(wavelet=wavelet, mode=mode, axes=axes)
|
||||
forward = partial(wavedec2, level=level, **kwargs)
|
||||
inverse = partial(waverec2, **kwargs)
|
||||
else:
|
||||
raise ValueError("unrecognized transform: {}".format(transform))
|
||||
|
||||
wav_coeffs = forward(data)
|
||||
|
||||
mra_coeffs = []
|
||||
nc = len(wav_coeffs)
|
||||
z = np.zeros_like(wav_coeffs[0])
|
||||
tmp = [z]
|
||||
for j in range(1, nc):
|
||||
tmp.append([np.zeros_like(c) for c in wav_coeffs[j]])
|
||||
|
||||
# tmp has arrays of zeros except for the jth entry
|
||||
tmp[0] = wav_coeffs[0]
|
||||
# reconstruct
|
||||
rec = inverse(tmp)
|
||||
if rec.shape != data.shape:
|
||||
# trim any excess coefficients
|
||||
rec = rec[tuple([slice(sz) for sz in data.shape])]
|
||||
mra_coeffs.append(rec)
|
||||
# restore zeros
|
||||
tmp[0] = z
|
||||
|
||||
for j in range(1, nc):
|
||||
dcoeffs = []
|
||||
for n in range(3):
|
||||
# tmp has arrays of zeros except for the jth entry
|
||||
z = tmp[j][n]
|
||||
tmp[j][n] = wav_coeffs[j][n]
|
||||
# reconstruct
|
||||
rec = inverse(tmp)
|
||||
if rec.shape != data.shape:
|
||||
# trim any excess coefficients
|
||||
rec = rec[tuple([slice(sz) for sz in data.shape])]
|
||||
dcoeffs.append(rec)
|
||||
# restore zeros
|
||||
tmp[j][n] = z
|
||||
mra_coeffs.append(tuple(dcoeffs))
|
||||
return mra_coeffs
|
||||
|
||||
|
||||
def imra2(mra_coeffs):
|
||||
"""Inverse 2D multiresolution analysis via summation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mra_coeffs : list
|
||||
Multiresolution analysis coefficients as returned by `mra2`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec : ndarray
|
||||
The reconstructed signal.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mra2
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
|
||||
Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
|
||||
Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
|
||||
https://doi.org/10.2307/2965551
|
||||
"""
|
||||
rec = mra_coeffs[0]
|
||||
for j in range(1, len(mra_coeffs)):
|
||||
for n in range(3):
|
||||
rec += mra_coeffs[j][n]
|
||||
return rec
|
||||
|
||||
|
||||
def mran(data, wavelet, level=None, axes=None, transform='swtn',
|
||||
mode='periodization'):
|
||||
"""Forward nD multiresolution analysis.
|
||||
|
||||
It is a projection onto the wavelet subspaces.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: array_like
|
||||
Input data
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in `axes`.
|
||||
level : int, optional
|
||||
Decomposition level (must be >= 0). If level is None (default) then it
|
||||
will be calculated using the `dwt_max_level` function.
|
||||
axes : tuple of ints, optional
|
||||
Axes over which to compute the DWT. Repeated elements are not allowed.
|
||||
transform : {'dwtn', 'swtn'}
|
||||
Whether to use the DWT or SWT for the transforms.
|
||||
mode : str or tuple of str, optional
|
||||
Signal extension mode, see `Modes` (default: 'symmetric'). This option
|
||||
is only used when transform='dwtn'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : list
|
||||
For more information, see the detailed description in `wavedecn`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
imran, swtn
|
||||
|
||||
Notes
|
||||
-----
|
||||
This is sometimes referred to as an additive decomposition because the
|
||||
inverse transform (``imran``) is just the sum of the coefficient arrays
|
||||
[1]_. The decomposition using ``transform='dwt'`` corresponds to section
|
||||
2.2 while that using an undecimated transform (``transform='swt'``) is
|
||||
described in section 3.2 and appendix A.
|
||||
|
||||
This transform does not share the variance partition property of ``swtn``
|
||||
with `norm=True`. It does however, result in coefficients that are
|
||||
temporally aligned regardless of the symmetry of the wavelet used.
|
||||
|
||||
The redundancy of this transform is ``(2**n - 1) * level + 1`` where ``n``
|
||||
corresponds to the number of axes transformed.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
|
||||
Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
|
||||
Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
|
||||
https://doi.org/10.2307/2965551
|
||||
"""
|
||||
axes, axes_shapes, ndim_transform = _prep_axes_wavedecn(data.shape, axes)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
|
||||
if transform == 'swtn':
|
||||
if mode != 'periodization':
|
||||
raise ValueError(
|
||||
"transform swt only supports mode='periodization'")
|
||||
if level is None:
|
||||
level = min(swt_max_level(s) for s in data.shape)
|
||||
kwargs = dict(wavelet=wavelets, axes=axes, norm=True)
|
||||
forward = partial(swtn, level=level, trim_approx=True, **kwargs)
|
||||
inverse = partial(iswtn, **kwargs)
|
||||
elif transform == 'dwtn':
|
||||
modes = _modes_per_axis(mode, axes)
|
||||
kwargs = dict(wavelet=wavelets, mode=modes, axes=axes)
|
||||
forward = partial(wavedecn, level=level, **kwargs)
|
||||
inverse = partial(waverecn, **kwargs)
|
||||
else:
|
||||
raise ValueError("unrecognized transform: {}".format(transform))
|
||||
|
||||
wav_coeffs = forward(data)
|
||||
|
||||
mra_coeffs = []
|
||||
nc = len(wav_coeffs)
|
||||
z = np.zeros_like(wav_coeffs[0])
|
||||
tmp = [z]
|
||||
for j in range(1, nc):
|
||||
tmp.append({k: np.zeros_like(v) for k, v in wav_coeffs[j].items()})
|
||||
|
||||
# tmp has arrays of zeros except for the jth entry
|
||||
tmp[0] = wav_coeffs[0]
|
||||
# reconstruct
|
||||
rec = inverse(tmp)
|
||||
if rec.shape != data.shape:
|
||||
# trim any excess coefficients
|
||||
rec = rec[tuple([slice(sz) for sz in data.shape])]
|
||||
mra_coeffs.append(rec)
|
||||
# restore zeros
|
||||
tmp[0] = z
|
||||
|
||||
for j in range(1, nc):
|
||||
dcoeffs = {}
|
||||
dkeys = list(wav_coeffs[j].keys())
|
||||
for k in dkeys:
|
||||
# tmp has arrays of zeros except for the jth entry
|
||||
z = tmp[j][k]
|
||||
tmp[j][k] = wav_coeffs[j][k]
|
||||
# tmp[j]['a' * len(k)] = z
|
||||
# reconstruct
|
||||
rec = inverse(tmp)
|
||||
if rec.shape != data.shape:
|
||||
# trim any excess coefficients
|
||||
rec = rec[tuple([slice(sz) for sz in data.shape])]
|
||||
dcoeffs[k] = rec
|
||||
# restore zeros
|
||||
tmp[j][k] = z
|
||||
# tmp[j].pop('a' * len(k))
|
||||
mra_coeffs.append(dcoeffs)
|
||||
return mra_coeffs
|
||||
|
||||
|
||||
def imran(mra_coeffs):
|
||||
"""Inverse nD multiresolution analysis via summation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mra_coeffs : list
|
||||
Multiresolution analysis coefficients as returned by `mra2`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec : ndarray
|
||||
The reconstructed signal.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mran
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
|
||||
Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
|
||||
Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
|
||||
https://doi.org/10.2307/2965551
|
||||
"""
|
||||
rec = mra_coeffs[0]
|
||||
for j in range(1, len(mra_coeffs)):
|
||||
for k, v in mra_coeffs[j].items():
|
||||
rec += v
|
||||
return rec
|
||||
314
.CondaPkg/env/Lib/site-packages/pywt/_multidim.py
vendored
Normal file
314
.CondaPkg/env/Lib/site-packages/pywt/_multidim.py
vendored
Normal file
@@ -0,0 +1,314 @@
|
||||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
2D and nD Discrete Wavelet Transforms and Inverse Discrete Wavelet Transforms.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._c99_config import _have_c99_complex
|
||||
from ._extensions._dwt import dwt_axis, idwt_axis
|
||||
from ._utils import _wavelets_per_axis, _modes_per_axis
|
||||
|
||||
|
||||
__all__ = ['dwt2', 'idwt2', 'dwtn', 'idwtn']
|
||||
|
||||
|
||||
def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)):
|
||||
"""
|
||||
2D Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
2D array with input data
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or 2-tuple of strings, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
|
||||
also be a tuple of modes specifying the mode to use on each axis in
|
||||
``axes``.
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the DWT. Repeated elements mean the DWT will
|
||||
be performed multiple times along these axes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(cA, (cH, cV, cD)) : tuple
|
||||
Approximation, horizontal detail, vertical detail and diagonal
|
||||
detail coefficients respectively. Horizontal refers to array axis 0
|
||||
(or ``axes[0]`` for user-specified ``axes``).
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pywt
|
||||
>>> data = np.ones((4,4), dtype=np.float64)
|
||||
>>> coeffs = pywt.dwt2(data, 'haar')
|
||||
>>> cA, (cH, cV, cD) = coeffs
|
||||
>>> cA
|
||||
array([[ 2., 2.],
|
||||
[ 2., 2.]])
|
||||
>>> cV
|
||||
array([[ 0., 0.],
|
||||
[ 0., 0.]])
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
data = np.asarray(data)
|
||||
if len(axes) != 2:
|
||||
raise ValueError("Expected 2 axes")
|
||||
if data.ndim < len(np.unique(axes)):
|
||||
raise ValueError("Input array has fewer dimensions than the specified "
|
||||
"axes")
|
||||
|
||||
coefs = dwtn(data, wavelet, mode, axes)
|
||||
return coefs['aa'], (coefs['da'], coefs['ad'], coefs['dd'])
|
||||
|
||||
|
||||
def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
|
||||
"""
|
||||
2-D Inverse Discrete Wavelet Transform.
|
||||
|
||||
Reconstructs data from coefficient arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : tuple
|
||||
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
|
||||
details coefficients 2D arrays like from ``dwt2``. If any of these
|
||||
components are set to ``None``, it will be treated as zeros.
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or 2-tuple of strings, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
|
||||
also be a tuple of modes specifying the mode to use on each axis in
|
||||
``axes``.
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the IDWT. Repeated elements mean the IDWT
|
||||
will be performed multiple times along these axes.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pywt
|
||||
>>> data = np.array([[1,2], [3,4]], dtype=np.float64)
|
||||
>>> coeffs = pywt.dwt2(data, 'haar')
|
||||
>>> pywt.idwt2(coeffs, 'haar')
|
||||
array([[ 1., 2.],
|
||||
[ 3., 4.]])
|
||||
|
||||
"""
|
||||
# L -low-pass data, H - high-pass data
|
||||
LL, (HL, LH, HH) = coeffs
|
||||
axes = tuple(axes)
|
||||
if len(axes) != 2:
|
||||
raise ValueError("Expected 2 axes")
|
||||
|
||||
coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
return idwtn(coeffs, wavelet, mode, axes)
|
||||
|
||||
|
||||
def dwtn(data, wavelet, mode='symmetric', axes=None):
|
||||
"""
|
||||
Single-level n-dimensional Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
n-dimensional array with input data.
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or tuple of string, optional
|
||||
Signal extension mode used in the decomposition,
|
||||
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
|
||||
specifying the mode to use on each axis in ``axes``.
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the DWT. Repeated elements mean the DWT will
|
||||
be performed multiple times along these axes. A value of ``None`` (the
|
||||
default) selects all axes.
|
||||
|
||||
Axes may be repeated, but information about the original size may be
|
||||
lost if it is not divisible by ``2 ** nrepeats``. The reconstruction
|
||||
will be larger, with additional values derived according to the
|
||||
``mode`` parameter. ``pywt.wavedecn`` should be used for multilevel
|
||||
decomposition.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : dict
|
||||
Results are arranged in a dictionary, where key specifies
|
||||
the transform type on each dimension and value is a n-dimensional
|
||||
coefficients array.
|
||||
|
||||
For example, for a 2D case the result will look something like this::
|
||||
|
||||
{'aa': <coeffs> # A(LL) - approx. on 1st dim, approx. on 2nd dim
|
||||
'ad': <coeffs> # V(LH) - approx. on 1st dim, det. on 2nd dim
|
||||
'da': <coeffs> # H(HL) - det. on 1st dim, approx. on 2nd dim
|
||||
'dd': <coeffs> # D(HH) - det. on 1st dim, det. on 2nd dim
|
||||
}
|
||||
|
||||
For user-specified ``axes``, the order of the characters in the
|
||||
dictionary keys map to the specified ``axes``.
|
||||
|
||||
"""
|
||||
data = np.asarray(data)
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
real = dwtn(data.real, wavelet, mode, axes)
|
||||
imag = dwtn(data.imag, wavelet, mode, axes)
|
||||
return dict((k, real[k] + 1j * imag[k]) for k in real.keys())
|
||||
|
||||
if data.dtype == np.dtype('object'):
|
||||
raise TypeError("Input must be a numeric array-like")
|
||||
if data.ndim < 1:
|
||||
raise ValueError("Input data must be at least 1D")
|
||||
|
||||
if axes is None:
|
||||
axes = range(data.ndim)
|
||||
axes = [a + data.ndim if a < 0 else a for a in axes]
|
||||
|
||||
modes = _modes_per_axis(mode, axes)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
|
||||
coeffs = [('', data)]
|
||||
for axis, wav, mode in zip(axes, wavelets, modes):
|
||||
new_coeffs = []
|
||||
for subband, x in coeffs:
|
||||
cA, cD = dwt_axis(x, wav, mode, axis)
|
||||
new_coeffs.extend([(subband + 'a', cA),
|
||||
(subband + 'd', cD)])
|
||||
coeffs = new_coeffs
|
||||
return dict(coeffs)
|
||||
|
||||
|
||||
def _fix_coeffs(coeffs):
|
||||
missing_keys = [k for k, v in coeffs.items() if v is None]
|
||||
if missing_keys:
|
||||
raise ValueError(
|
||||
"The following detail coefficients were set to None:\n"
|
||||
"{0}\n"
|
||||
"For multilevel transforms, rather than setting\n"
|
||||
"\tcoeffs[key] = None\n"
|
||||
"use\n"
|
||||
"\tcoeffs[key] = np.zeros_like(coeffs[key])\n".format(
|
||||
missing_keys))
|
||||
|
||||
invalid_keys = [k for k, v in coeffs.items() if
|
||||
not set(k) <= set('ad')]
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
"The following invalid keys were found in the detail "
|
||||
"coefficient dictionary: {}.".format(invalid_keys))
|
||||
|
||||
key_lengths = [len(k) for k in coeffs.keys()]
|
||||
if len(np.unique(key_lengths)) > 1:
|
||||
raise ValueError(
|
||||
"All detail coefficient names must have equal length.")
|
||||
|
||||
return dict((k, np.asarray(v)) for k, v in coeffs.items())
|
||||
|
||||
|
||||
def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
|
||||
"""
|
||||
Single-level n-dimensional Inverse Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs: dict
|
||||
Dictionary as in output of ``dwtn``. Missing or ``None`` items
|
||||
will be treated as zeros.
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or list of string, optional
|
||||
Signal extension mode used in the decomposition,
|
||||
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
|
||||
specifying the mode to use on each axis in ``axes``.
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the IDWT. Repeated elements mean the IDWT
|
||||
will be performed multiple times along these axes. A value of ``None``
|
||||
(the default) selects all axes.
|
||||
|
||||
For the most accurate reconstruction, the axes should be provided in
|
||||
the same order as they were provided to ``dwtn``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data: ndarray
|
||||
Original signal reconstructed from input data.
|
||||
|
||||
"""
|
||||
|
||||
# drop the keys corresponding to value = None
|
||||
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
|
||||
|
||||
# drop the keys corresponding to value = None
|
||||
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
|
||||
|
||||
# Raise error for invalid key combinations
|
||||
coeffs = _fix_coeffs(coeffs)
|
||||
|
||||
if (not _have_c99_complex and
|
||||
any(np.iscomplexobj(v) for v in coeffs.values())):
|
||||
real_coeffs = dict((k, v.real) for k, v in coeffs.items())
|
||||
imag_coeffs = dict((k, v.imag) for k, v in coeffs.items())
|
||||
return (idwtn(real_coeffs, wavelet, mode, axes) +
|
||||
1j * idwtn(imag_coeffs, wavelet, mode, axes))
|
||||
|
||||
# key length matches the number of axes transformed
|
||||
ndim_transform = max(len(key) for key in coeffs.keys())
|
||||
|
||||
try:
|
||||
coeff_shapes = (v.shape for k, v in coeffs.items()
|
||||
if v is not None and len(k) == ndim_transform)
|
||||
coeff_shape = next(coeff_shapes)
|
||||
except StopIteration:
|
||||
raise ValueError("`coeffs` must contain at least one non-null wavelet "
|
||||
"band")
|
||||
if any(s != coeff_shape for s in coeff_shapes):
|
||||
raise ValueError("`coeffs` must all be of equal size (or None)")
|
||||
|
||||
if axes is None:
|
||||
axes = range(ndim_transform)
|
||||
ndim = ndim_transform
|
||||
else:
|
||||
ndim = len(coeff_shape)
|
||||
axes = [a + ndim if a < 0 else a for a in axes]
|
||||
|
||||
modes = _modes_per_axis(mode, axes)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
for key_length, (axis, wav, mode) in reversed(
|
||||
list(enumerate(zip(axes, wavelets, modes)))):
|
||||
if axis < 0 or axis >= ndim:
|
||||
raise np.AxisError("Axis greater than data dimensions")
|
||||
|
||||
new_coeffs = {}
|
||||
new_keys = [''.join(coef) for coef in product('ad', repeat=key_length)]
|
||||
|
||||
for key in new_keys:
|
||||
L = coeffs.get(key + 'a', None)
|
||||
H = coeffs.get(key + 'd', None)
|
||||
if L is not None and H is not None:
|
||||
if L.dtype != H.dtype:
|
||||
# upcast to a common dtype (float64 or complex128)
|
||||
if L.dtype.kind == 'c' or H.dtype.kind == 'c':
|
||||
dtype = np.complex128
|
||||
else:
|
||||
dtype = np.float64
|
||||
L = np.asarray(L, dtype=dtype)
|
||||
H = np.asarray(H, dtype=dtype)
|
||||
new_coeffs[key] = idwt_axis(L, H, wav, mode, axis)
|
||||
coeffs = new_coeffs
|
||||
|
||||
return coeffs['']
|
||||
1561
.CondaPkg/env/Lib/site-packages/pywt/_multilevel.py
vendored
Normal file
1561
.CondaPkg/env/Lib/site-packages/pywt/_multilevel.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
68
.CondaPkg/env/Lib/site-packages/pywt/_pytest.py
vendored
Normal file
68
.CondaPkg/env/Lib/site-packages/pywt/_pytest.py
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
"""common test-related code."""
|
||||
import os
|
||||
import sys
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
__all__ = ['uses_matlab', # skip if pymatbridge and Matlab unavailable
|
||||
'uses_futures', # skip if futures unavailable
|
||||
'uses_pymatbridge', # skip if no PYWT_XSLOW environment variable
|
||||
'uses_precomputed', # skip if PYWT_XSLOW environment variable found
|
||||
'matlab_result_dict_cwt', # dict with precomputed Matlab dwt data
|
||||
'matlab_result_dict_dwt', # dict with precomputed Matlab cwt data
|
||||
'futures', # the futures module or None
|
||||
'max_workers', # the number of workers available to futures
|
||||
'size_set', # the set of Matlab tests to run
|
||||
]
|
||||
|
||||
try:
|
||||
if sys.version_info[0] == 2:
|
||||
import futures
|
||||
else:
|
||||
from concurrent import futures
|
||||
max_workers = multiprocessing.cpu_count()
|
||||
futures_available = True
|
||||
except ImportError:
|
||||
futures_available = False
|
||||
futures = None
|
||||
|
||||
# check if pymatbridge + MATLAB tests should be run
|
||||
matlab_result_dict_dwt = None
|
||||
matlab_result_dict_cwt = None
|
||||
matlab_missing = True
|
||||
use_precomputed = True
|
||||
size_set = 'reduced'
|
||||
if 'PYWT_XSLOW' in os.environ:
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
matlab_missing = False
|
||||
use_precomputed = False
|
||||
size_set = 'full'
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
if use_precomputed:
|
||||
# load dictionaries of precomputed results
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'tests', 'data')
|
||||
matlab_data_file_cwt = os.path.join(
|
||||
data_dir, 'cwt_matlabR2015b_result.npz')
|
||||
matlab_result_dict_cwt = np.load(matlab_data_file_cwt)
|
||||
|
||||
matlab_data_file_dwt = os.path.join(
|
||||
data_dir, 'dwt_matlabR2012a_result.npz')
|
||||
matlab_result_dict_dwt = np.load(matlab_data_file_dwt)
|
||||
|
||||
uses_futures = pytest.mark.skipif(
|
||||
not futures_available, reason='futures not available')
|
||||
uses_matlab = pytest.mark.skipif(
|
||||
matlab_missing, reason='pymatbridge and/or Matlab not available')
|
||||
uses_pymatbridge = pytest.mark.skipif(
|
||||
use_precomputed,
|
||||
reason='PYWT_XSLOW set: skipping tests against precomputed Matlab results')
|
||||
uses_precomputed = pytest.mark.skipif(
|
||||
not use_precomputed,
|
||||
reason='PYWT_XSLOW not set: test against precomputed matlab tests')
|
||||
164
.CondaPkg/env/Lib/site-packages/pywt/_pytesttester.py
vendored
Normal file
164
.CondaPkg/env/Lib/site-packages/pywt/_pytesttester.py
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Pytest test running.
|
||||
|
||||
This module implements the ``test()`` function for NumPy modules. The usual
|
||||
boiler plate for doing that is to put the following in the module
|
||||
``__init__.py`` file::
|
||||
|
||||
from pywt._pytesttester import PytestTester
|
||||
test = PytestTester(__name__).test
|
||||
del PytestTester
|
||||
|
||||
|
||||
Warnings filtering and other runtime settings should be dealt with in the
|
||||
``pytest.ini`` file in the pywt repo root. The behavior of the test depends on
|
||||
whether or not that file is found as follows:
|
||||
|
||||
* ``pytest.ini`` is present (develop mode)
|
||||
All warnings except those explicily filtered out are raised as error.
|
||||
* ``pytest.ini`` is absent (release mode)
|
||||
DeprecationWarnings and PendingDeprecationWarnings are ignored, other
|
||||
warnings are passed through.
|
||||
|
||||
In practice, tests run from the PyWavelets repo are run in develop mode. That
|
||||
includes the standard ``python runtests.py`` invocation.
|
||||
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
__all__ = ['PytestTester']
|
||||
|
||||
|
||||
def _show_pywt_info():
|
||||
import pywt
|
||||
from pywt._c99_config import _have_c99_complex
|
||||
print("PyWavelets version %s" % pywt.__version__)
|
||||
if _have_c99_complex:
|
||||
print("Compiled with C99 complex support.")
|
||||
else:
|
||||
print("Compiled without C99 complex support.")
|
||||
|
||||
|
||||
class PytestTester(object):
|
||||
"""
|
||||
Pytest test runner.
|
||||
|
||||
This class is made available in ``pywt.testing``, and a test function
|
||||
is typically added to a package's __init__.py like so::
|
||||
|
||||
from pywt.testing import PytestTester
|
||||
test = PytestTester(__name__).test
|
||||
del PytestTester
|
||||
|
||||
Calling this test function finds and runs all tests associated with the
|
||||
module and all its sub-modules.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
module_name : str
|
||||
Full path to the package to test.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module_name : module name
|
||||
The name of the module to test.
|
||||
|
||||
"""
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
|
||||
def __call__(self, label='fast', verbose=1, extra_argv=None,
|
||||
doctests=False, coverage=False, durations=-1, tests=None):
|
||||
"""
|
||||
Run tests for module using pytest.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : {'fast', 'full'}, optional
|
||||
Identifies the tests to run. When set to 'fast', tests decorated
|
||||
with `pytest.mark.slow` are skipped, when 'full', the slow marker
|
||||
is ignored.
|
||||
verbose : int, optional
|
||||
Verbosity value for test outputs, in the range 1-3. Default is 1.
|
||||
extra_argv : list, optional
|
||||
List with any extra arguments to pass to pytests.
|
||||
doctests : bool, optional
|
||||
.. note:: Not supported
|
||||
coverage : bool, optional
|
||||
If True, report coverage of NumPy code. Default is False.
|
||||
Requires installation of (pip) pytest-cov.
|
||||
durations : int, optional
|
||||
If < 0, do nothing, If 0, report time of all tests, if > 0,
|
||||
report the time of the slowest `timer` tests. Default is -1.
|
||||
tests : test or list of tests
|
||||
Tests to be executed with pytest '--pyargs'
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : bool
|
||||
Return True on success, false otherwise.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> result = np.lib.test() #doctest: +SKIP
|
||||
...
|
||||
1023 passed, 2 skipped, 6 deselected, 1 xfailed in 10.39 seconds
|
||||
>>> result
|
||||
True
|
||||
|
||||
"""
|
||||
import pytest
|
||||
|
||||
module = sys.modules[self.module_name]
|
||||
module_path = os.path.abspath(module.__path__[0])
|
||||
|
||||
# setup the pytest arguments
|
||||
pytest_args = ["-l"]
|
||||
|
||||
# offset verbosity. The "-q" cancels a "-v".
|
||||
pytest_args += ["-q"]
|
||||
|
||||
# Filter out annoying import messages. Want these in both develop and
|
||||
# release mode.
|
||||
pytest_args += [
|
||||
"-W ignore:Not importing directory",
|
||||
"-W ignore:numpy.dtype size changed",
|
||||
"-W ignore:numpy.ufunc size changed", ]
|
||||
|
||||
if doctests:
|
||||
raise ValueError("Doctests not supported")
|
||||
|
||||
if extra_argv:
|
||||
pytest_args += list(extra_argv)
|
||||
|
||||
if verbose > 1:
|
||||
pytest_args += ["-" + "v"*(verbose - 1)]
|
||||
|
||||
if coverage:
|
||||
pytest_args += ["--cov=" + module_path]
|
||||
|
||||
if label == "fast":
|
||||
pytest_args += ["-m", "not slow"]
|
||||
elif label != "full":
|
||||
pytest_args += ["-m", label]
|
||||
|
||||
if durations >= 0:
|
||||
pytest_args += ["--durations=%s" % durations]
|
||||
|
||||
if tests is None:
|
||||
tests = [self.module_name]
|
||||
|
||||
pytest_args += ["--pyargs"] + list(tests)
|
||||
|
||||
# run tests.
|
||||
_show_pywt_info()
|
||||
|
||||
try:
|
||||
code = pytest.main(pytest_args)
|
||||
except SystemExit as exc:
|
||||
code = exc.code
|
||||
|
||||
return code == 0
|
||||
824
.CondaPkg/env/Lib/site-packages/pywt/_swt.py
vendored
Normal file
824
.CondaPkg/env/Lib/site-packages/pywt/_swt.py
vendored
Normal file
@@ -0,0 +1,824 @@
|
||||
import warnings
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._c99_config import _have_c99_complex
|
||||
from ._extensions._dwt import idwt_single
|
||||
from ._extensions._swt import swt_max_level, swt as _swt, swt_axis as _swt_axis
|
||||
from ._extensions._pywt import Wavelet, Modes, _check_dtype
|
||||
from ._multidim import idwt2, idwtn
|
||||
from ._utils import _as_wavelet, _wavelets_per_axis
|
||||
|
||||
|
||||
__all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn']
|
||||
|
||||
|
||||
def _rescale_wavelet_filterbank(wavelet, sf):
|
||||
wav = Wavelet(wavelet.name + 'r',
|
||||
[np.asarray(f) * sf for f in wavelet.filter_bank])
|
||||
|
||||
# copy attributes from the original wavelet
|
||||
wav.orthogonal = wavelet.orthogonal
|
||||
wav.biorthogonal = wavelet.biorthogonal
|
||||
return wav
|
||||
|
||||
|
||||
def swt(data, wavelet, level=None, start_level=0, axis=-1,
|
||||
trim_approx=False, norm=False):
|
||||
"""
|
||||
Multilevel 1D stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data :
|
||||
Input signal
|
||||
wavelet :
|
||||
Wavelet to use (Wavelet object or name)
|
||||
level : int, optional
|
||||
The number of decomposition steps to perform.
|
||||
start_level : int, optional
|
||||
The level at which the decomposition will begin (it allows one to
|
||||
skip a given number of transform steps and compute
|
||||
coefficients starting from start_level) (default: 0)
|
||||
axis: int, optional
|
||||
Axis over which to compute the SWT. If not given, the
|
||||
last axis is used.
|
||||
trim_approx : bool, optional
|
||||
If True, approximation coefficients at the final level are retained.
|
||||
norm : bool, optional
|
||||
If True, transform is normalized so that the energy of the coefficients
|
||||
will be equal to the energy of ``data``. In other words,
|
||||
``np.linalg.norm(data.ravel())`` will equal the norm of the
|
||||
concatenated transform coefficients when ``trim_approx`` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : list
|
||||
List of approximation and details coefficients pairs in order
|
||||
similar to wavedec function::
|
||||
|
||||
[(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
|
||||
|
||||
where n equals input parameter ``level``.
|
||||
|
||||
If ``start_level = m`` is given, then the beginning m steps are
|
||||
skipped::
|
||||
|
||||
[(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)]
|
||||
|
||||
If ``trim_approx`` is ``True``, then the output list is exactly as in
|
||||
``pywt.wavedec``, where the first coefficient in the list is the
|
||||
approximation coefficient at the final level and the rest are the
|
||||
detail coefficients::
|
||||
|
||||
[cAn, cDn, ..., cD2, cD1]
|
||||
|
||||
Notes
|
||||
-----
|
||||
The implementation here follows the "algorithm a-trous" and requires that
|
||||
the signal length along the transformed axis be a multiple of ``2**level``.
|
||||
If this is not the case, the user should pad up to an appropriate size
|
||||
using a function such as ``numpy.pad``.
|
||||
|
||||
A primary benefit of this transform in comparison to its decimated
|
||||
counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes
|
||||
at cost of redundancy in the transform (the size of the output coefficients
|
||||
is larger than the input).
|
||||
|
||||
When the following three conditions are true:
|
||||
|
||||
1. The wavelet is orthogonal
|
||||
2. ``swt`` is called with ``norm=True``
|
||||
3. ``swt`` is called with ``trim_approx=True``
|
||||
|
||||
the transform has the following additional properties that may be
|
||||
desirable in applications:
|
||||
|
||||
1. energy is conserved
|
||||
2. variance is partitioned across scales
|
||||
|
||||
When used with ``norm=True``, this transform is closely related to the
|
||||
multiple-overlap DWT (MODWT) as popularized for time-series analysis,
|
||||
although the underlying implementation is slightly different from the one
|
||||
published in [1]_. Specifically, the implementation used here requires a
|
||||
signal that is a multiple of ``2**level`` in length.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] DB Percival and AT Walden. Wavelet Methods for Time Series Analysis.
|
||||
Cambridge University Press, 2000.
|
||||
"""
|
||||
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
data = np.asarray(data)
|
||||
kwargs = dict(wavelet=wavelet, level=level, start_level=start_level,
|
||||
trim_approx=trim_approx, axis=axis, norm=norm)
|
||||
coeffs_real = swt(data.real, **kwargs)
|
||||
coeffs_imag = swt(data.imag, **kwargs)
|
||||
if not trim_approx:
|
||||
coeffs_cplx = []
|
||||
for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag):
|
||||
coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i))
|
||||
else:
|
||||
coeffs_cplx = [cr + 1j*ci
|
||||
for (cr, ci) in zip(coeffs_real, coeffs_imag)]
|
||||
return coeffs_cplx
|
||||
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.array(data, dtype=dt)
|
||||
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
if norm:
|
||||
if not wavelet.orthogonal:
|
||||
warnings.warn(
|
||||
"norm=True, but the wavelet is not orthogonal: \n"
|
||||
"\tThe conditions for energy preservation are not satisfied.")
|
||||
wavelet = _rescale_wavelet_filterbank(wavelet, 1/np.sqrt(2))
|
||||
|
||||
if axis < 0:
|
||||
axis = axis + data.ndim
|
||||
if not 0 <= axis < data.ndim:
|
||||
raise np.AxisError("Axis greater than data dimensions")
|
||||
|
||||
if level is None:
|
||||
level = swt_max_level(data.shape[axis])
|
||||
|
||||
if data.ndim == 1:
|
||||
ret = _swt(data, wavelet, level, start_level, trim_approx)
|
||||
else:
|
||||
ret = _swt_axis(data, wavelet, level, start_level, axis, trim_approx)
|
||||
return ret
|
||||
|
||||
|
||||
def iswt(coeffs, wavelet, norm=False, axis=-1):
|
||||
"""
|
||||
Multilevel 1D inverse discrete stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : array_like
|
||||
Coefficients list of tuples::
|
||||
|
||||
[(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
|
||||
|
||||
where cA is approximation, cD is details. Index 1 corresponds to
|
||||
``start_level`` from ``pywt.swt``.
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet to use
|
||||
norm : bool, optional
|
||||
Controls the normalization used by the inverse transform. This must
|
||||
be set equal to the value that was used by ``pywt.swt`` to preserve the
|
||||
energy of a round-trip transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
1D array of reconstructed data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> coeffs = pywt.swt([1,2,3,4,5,6,7,8], 'db2', level=2)
|
||||
>>> pywt.iswt(coeffs, 'db2')
|
||||
array([ 1., 2., 3., 4., 5., 6., 7., 8.])
|
||||
"""
|
||||
# copy to avoid modification of input data
|
||||
# If swt was called with trim_approx=False, first element is a tuple
|
||||
trim_approx = not isinstance(coeffs[0], (tuple, list))
|
||||
cA = coeffs[0] if trim_approx else coeffs[0][0]
|
||||
if cA.ndim > 1:
|
||||
# convert to swtn coefficient format and call iswtn
|
||||
if trim_approx:
|
||||
coeffs_nd = [cA] + [{'d': d} for d in coeffs[1:]]
|
||||
else:
|
||||
coeffs_nd = [{'a': a, 'd': d} for a, d in coeffs]
|
||||
return iswtn(coeffs_nd, wavelet, axes=(axis,), norm=norm)
|
||||
elif axis != 0 and axis != -1:
|
||||
raise np.AxisError("Axis greater than data dimensions")
|
||||
if not _have_c99_complex and np.iscomplexobj(cA):
|
||||
if trim_approx:
|
||||
coeffs_real = [c.real for c in coeffs]
|
||||
coeffs_imag = [c.imag for c in coeffs]
|
||||
else:
|
||||
coeffs_real = [(ca.real, cd.real) for ca, cd in coeffs]
|
||||
coeffs_imag = [(ca.imag, cd.imag) for ca, cd in coeffs]
|
||||
kwargs = dict(wavelet=wavelet, norm=norm)
|
||||
y = iswt(coeffs_real, **kwargs)
|
||||
return y + 1j * iswt(coeffs_imag, **kwargs)
|
||||
|
||||
if trim_approx:
|
||||
coeffs = coeffs[1:]
|
||||
|
||||
if cA.ndim != 1:
|
||||
raise ValueError("iswt only supports 1D data")
|
||||
|
||||
dt = _check_dtype(cA)
|
||||
output = np.array(cA, dtype=dt, copy=True)
|
||||
|
||||
# num_levels, equivalent to the decomposition level, n
|
||||
num_levels = len(coeffs)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
if norm:
|
||||
wavelet = _rescale_wavelet_filterbank(wavelet, np.sqrt(2))
|
||||
mode = Modes.from_object('periodization')
|
||||
for j in range(num_levels, 0, -1):
|
||||
step_size = int(pow(2, j-1))
|
||||
last_index = step_size
|
||||
if trim_approx:
|
||||
cD = coeffs[-j]
|
||||
else:
|
||||
_, cD = coeffs[-j]
|
||||
cD = np.asarray(cD, dtype=_check_dtype(cD))
|
||||
if cD.dtype != output.dtype:
|
||||
# upcast to a common dtype (float64 or complex128)
|
||||
if output.dtype.kind == 'c' or cD.dtype.kind == 'c':
|
||||
dtype = np.complex128
|
||||
else:
|
||||
dtype = np.float64
|
||||
output = np.asarray(output, dtype=dtype)
|
||||
cD = np.asarray(cD, dtype=dtype)
|
||||
for first in range(last_index): # 0 to last_index - 1
|
||||
|
||||
# Getting the indices that we will transform
|
||||
indices = np.arange(first, len(cD), step_size)
|
||||
|
||||
# select the even indices
|
||||
even_indices = indices[0::2]
|
||||
# select the odd indices
|
||||
odd_indices = indices[1::2]
|
||||
|
||||
# perform the inverse dwt on the selected indices,
|
||||
# making sure to use periodic boundary conditions
|
||||
# Note: indexing with an array of ints returns a contiguous
|
||||
# copy as required by idwt_single.
|
||||
x1 = idwt_single(output[even_indices],
|
||||
cD[even_indices],
|
||||
wavelet, mode)
|
||||
x2 = idwt_single(output[odd_indices],
|
||||
cD[odd_indices],
|
||||
wavelet, mode)
|
||||
|
||||
# perform a circular shift right
|
||||
x2 = np.roll(x2, 1)
|
||||
|
||||
# average and insert into the correct indices
|
||||
output[indices] = (x1 + x2)/2.
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def swt2(data, wavelet, level, start_level=0, axes=(-2, -1),
|
||||
trim_approx=False, norm=False):
|
||||
"""
|
||||
Multilevel 2D stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
2D array with input data
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple of wavelets to apply per
|
||||
axis in ``axes``.
|
||||
level : int
|
||||
The number of decomposition steps to perform.
|
||||
start_level : int, optional
|
||||
The level at which the decomposition will start (default: 0)
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the SWT. Repeated elements are not allowed.
|
||||
trim_approx : bool, optional
|
||||
If True, approximation coefficients at the final level are retained.
|
||||
norm : bool, optional
|
||||
If True, transform is normalized so that the energy of the coefficients
|
||||
will be equal to the energy of ``data``. In other words,
|
||||
``np.linalg.norm(data.ravel())`` will equal the norm of the
|
||||
concatenated transform coefficients when ``trim_approx`` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : list
|
||||
Approximation and details coefficients (for ``start_level = m``).
|
||||
If ``trim_approx`` is ``False``, approximation coefficients are
|
||||
retained for all levels::
|
||||
|
||||
[
|
||||
(cA_m+level,
|
||||
(cH_m+level, cV_m+level, cD_m+level)
|
||||
),
|
||||
...,
|
||||
(cA_m+1,
|
||||
(cH_m+1, cV_m+1, cD_m+1)
|
||||
),
|
||||
(cA_m,
|
||||
(cH_m, cV_m, cD_m)
|
||||
)
|
||||
]
|
||||
|
||||
where cA is approximation, cH is horizontal details, cV is
|
||||
vertical details, cD is diagonal details and m is ``start_level``.
|
||||
|
||||
If ``trim_approx`` is ``True``, approximation coefficients are only
|
||||
retained at the final level of decomposition. This matches the format
|
||||
used by ``pywt.wavedec2``::
|
||||
|
||||
[
|
||||
cA_m+level,
|
||||
(cH_m+level, cV_m+level, cD_m+level),
|
||||
...,
|
||||
(cH_m+1, cV_m+1, cD_m+1),
|
||||
(cH_m, cV_m, cD_m),
|
||||
]
|
||||
|
||||
Notes
|
||||
-----
|
||||
The implementation here follows the "algorithm a-trous" and requires that
|
||||
the signal length along the transformed axes be a multiple of ``2**level``.
|
||||
If this is not the case, the user should pad up to an appropriate size
|
||||
using a function such as ``numpy.pad``.
|
||||
|
||||
A primary benefit of this transform in comparison to its decimated
|
||||
counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes
|
||||
at cost of redundancy in the transform (the size of the output coefficients
|
||||
is larger than the input).
|
||||
|
||||
When the following three conditions are true:
|
||||
|
||||
1. The wavelet is orthogonal
|
||||
2. ``swt2`` is called with ``norm=True``
|
||||
3. ``swt2`` is called with ``trim_approx=True``
|
||||
|
||||
the transform has the following additional properties that may be
|
||||
desirable in applications:
|
||||
|
||||
1. energy is conserved
|
||||
2. variance is partitioned across scales
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
data = np.asarray(data)
|
||||
if len(axes) != 2:
|
||||
raise ValueError("Expected 2 axes")
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError("The axes passed to swt2 must be unique.")
|
||||
if data.ndim < len(np.unique(axes)):
|
||||
raise ValueError("Input array has fewer dimensions than the specified "
|
||||
"axes")
|
||||
|
||||
coefs = swtn(data, wavelet, level, start_level, axes, trim_approx, norm)
|
||||
ret = []
|
||||
if trim_approx:
|
||||
ret.append(coefs[0])
|
||||
coefs = coefs[1:]
|
||||
for c in coefs:
|
||||
if trim_approx:
|
||||
ret.append((c['da'], c['ad'], c['dd']))
|
||||
else:
|
||||
ret.append((c['aa'], (c['da'], c['ad'], c['dd'])))
|
||||
return ret
|
||||
|
||||
|
||||
def iswt2(coeffs, wavelet, norm=False, axes=(-2, -1)):
|
||||
"""
|
||||
Multilevel 2D inverse discrete stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : list
|
||||
Approximation and details coefficients::
|
||||
|
||||
[
|
||||
(cA_n,
|
||||
(cH_n, cV_n, cD_n)
|
||||
),
|
||||
...,
|
||||
(cA_2,
|
||||
(cH_2, cV_2, cD_2)
|
||||
),
|
||||
(cA_1,
|
||||
(cH_1, cV_1, cD_1)
|
||||
)
|
||||
]
|
||||
|
||||
where cA is approximation, cH is horizontal details, cV is
|
||||
vertical details, cD is diagonal details and n is the number of
|
||||
levels. Index 1 corresponds to ``start_level`` from ``pywt.swt2``.
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a 2-tuple of wavelets to apply per
|
||||
axis.
|
||||
norm : bool, optional
|
||||
Controls the normalization used by the inverse transform. This must
|
||||
be set equal to the value that was used by ``pywt.swt2`` to preserve
|
||||
the energy of a round-trip transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
2D array of reconstructed data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> coeffs = pywt.swt2([[1,2,3,4],[5,6,7,8],
|
||||
... [9,10,11,12],[13,14,15,16]],
|
||||
... 'db1', level=2)
|
||||
>>> pywt.iswt2(coeffs, 'db1')
|
||||
array([[ 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8.],
|
||||
[ 9., 10., 11., 12.],
|
||||
[ 13., 14., 15., 16.]])
|
||||
|
||||
"""
|
||||
|
||||
# If swt was called with trim_approx=False, first element is a tuple
|
||||
trim_approx = not isinstance(coeffs[0], (tuple, list))
|
||||
cA = coeffs[0] if trim_approx else coeffs[0][0]
|
||||
if cA.ndim != 2 or axes != (-2, -1):
|
||||
# convert to swtn coefficient format and call iswtn instead
|
||||
if trim_approx:
|
||||
coeffs_nd = [cA] + [{'da': h, 'ad': v, 'dd': d}
|
||||
for h, v, d in coeffs[1:]]
|
||||
else:
|
||||
coeffs_nd = [{'aa': a, 'da': h, 'ad': v, 'dd': d}
|
||||
for a, (h, v, d) in coeffs]
|
||||
return iswtn(coeffs_nd, wavelet, axes=axes, norm=norm)
|
||||
if not _have_c99_complex and np.iscomplexobj(cA):
|
||||
if trim_approx:
|
||||
coeffs_real = [cA.real]
|
||||
coeffs_real += [(h.real, v.real, d.real) for h, v, d in coeffs[1:]]
|
||||
coeffs_imag = [cA.imag]
|
||||
coeffs_imag += [(h.imag, v.imag, d.imag) for h, v, d in coeffs[1:]]
|
||||
else:
|
||||
coeffs_real = [(a.real, (h.real, v.real, d.real))
|
||||
for a, (h, v, d) in coeffs]
|
||||
coeffs_imag = [(a.imag, (h.imag, v.imag, d.imag))
|
||||
for a, (h, v, d) in coeffs]
|
||||
kwargs = dict(wavelet=wavelet, norm=norm)
|
||||
y = iswt2(coeffs_real, **kwargs)
|
||||
return y + 1j * iswt2(coeffs_imag, **kwargs)
|
||||
|
||||
if trim_approx:
|
||||
coeffs = coeffs[1:]
|
||||
|
||||
# copy to avoid modification of input data
|
||||
dt = _check_dtype(cA)
|
||||
output = np.array(cA, dtype=dt, copy=True)
|
||||
|
||||
if output.ndim != 2:
|
||||
raise ValueError(
|
||||
"iswt2 only supports 2D arrays. see iswtn for a general "
|
||||
"n-dimensionsal ISWT")
|
||||
# num_levels, equivalent to the decomposition level, n
|
||||
num_levels = len(coeffs)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes=(0, 1))
|
||||
if norm:
|
||||
wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2))
|
||||
for wav in wavelets]
|
||||
|
||||
for j in range(num_levels):
|
||||
step_size = int(pow(2, num_levels-j-1))
|
||||
last_index = step_size
|
||||
if trim_approx:
|
||||
(cH, cV, cD) = coeffs[j]
|
||||
else:
|
||||
_, (cH, cV, cD) = coeffs[j]
|
||||
# We are going to assume cH, cV, and cD are of equal size
|
||||
if (cH.shape != cV.shape) or (cH.shape != cD.shape):
|
||||
raise RuntimeError(
|
||||
"Mismatch in shape of intermediate coefficient arrays")
|
||||
|
||||
# make sure output shares the common dtype
|
||||
# (conversion of dtype for individual coeffs is handled within idwt2 )
|
||||
common_dtype = np.result_type(*(
|
||||
[dt, ] + [_check_dtype(c) for c in [cH, cV, cD]]))
|
||||
if output.dtype != common_dtype:
|
||||
output = output.astype(common_dtype)
|
||||
|
||||
for first_h in range(last_index): # 0 to last_index - 1
|
||||
for first_w in range(last_index): # 0 to last_index - 1
|
||||
# Getting the indices that we will transform
|
||||
indices_h = slice(first_h, cH.shape[0], step_size)
|
||||
indices_w = slice(first_w, cH.shape[1], step_size)
|
||||
|
||||
even_idx_h = slice(first_h, cH.shape[0], 2*step_size)
|
||||
even_idx_w = slice(first_w, cH.shape[1], 2*step_size)
|
||||
odd_idx_h = slice(first_h + step_size, cH.shape[0], 2*step_size)
|
||||
odd_idx_w = slice(first_w + step_size, cH.shape[1], 2*step_size)
|
||||
|
||||
# perform the inverse dwt on the selected indices,
|
||||
# making sure to use periodic boundary conditions
|
||||
x1 = idwt2((output[even_idx_h, even_idx_w],
|
||||
(cH[even_idx_h, even_idx_w],
|
||||
cV[even_idx_h, even_idx_w],
|
||||
cD[even_idx_h, even_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
x2 = idwt2((output[even_idx_h, odd_idx_w],
|
||||
(cH[even_idx_h, odd_idx_w],
|
||||
cV[even_idx_h, odd_idx_w],
|
||||
cD[even_idx_h, odd_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
x3 = idwt2((output[odd_idx_h, even_idx_w],
|
||||
(cH[odd_idx_h, even_idx_w],
|
||||
cV[odd_idx_h, even_idx_w],
|
||||
cD[odd_idx_h, even_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
x4 = idwt2((output[odd_idx_h, odd_idx_w],
|
||||
(cH[odd_idx_h, odd_idx_w],
|
||||
cV[odd_idx_h, odd_idx_w],
|
||||
cD[odd_idx_h, odd_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
|
||||
# perform a circular shifts
|
||||
x2 = np.roll(x2, 1, axis=1)
|
||||
x3 = np.roll(x3, 1, axis=0)
|
||||
x4 = np.roll(x4, 1, axis=0)
|
||||
x4 = np.roll(x4, 1, axis=1)
|
||||
output[indices_h, indices_w] = (x1 + x2 + x3 + x4) / 4
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False,
|
||||
norm=False):
|
||||
"""
|
||||
n-dimensional stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
n-dimensional array with input data.
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple of wavelets to apply per
|
||||
axis in ``axes``.
|
||||
level : int
|
||||
The number of decomposition steps to perform.
|
||||
start_level : int, optional
|
||||
The level at which the decomposition will start (default: 0)
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the SWT. A value of ``None`` (the
|
||||
default) selects all axes. Axes may not be repeated.
|
||||
trim_approx : bool, optional
|
||||
If True, approximation coefficients at the final level are retained.
|
||||
norm : bool, optional
|
||||
If True, transform is normalized so that the energy of the coefficients
|
||||
will be equal to the energy of ``data``. In other words,
|
||||
``np.linalg.norm(data.ravel())`` will equal the norm of the
|
||||
concatenated transform coefficients when ``trim_approx`` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
[{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
|
||||
Results for each level are arranged in a dictionary, where the key
|
||||
specifies the transform type on each dimension and value is a
|
||||
n-dimensional coefficients array.
|
||||
|
||||
For example, for a 2D case the result at a given level will look
|
||||
something like this::
|
||||
|
||||
{'aa': <coeffs> # A(LL) - approx. on 1st dim, approx. on 2nd dim
|
||||
'ad': <coeffs> # V(LH) - approx. on 1st dim, det. on 2nd dim
|
||||
'da': <coeffs> # H(HL) - det. on 1st dim, approx. on 2nd dim
|
||||
'dd': <coeffs> # D(HH) - det. on 1st dim, det. on 2nd dim
|
||||
}
|
||||
|
||||
For user-specified ``axes``, the order of the characters in the
|
||||
dictionary keys map to the specified ``axes``.
|
||||
|
||||
If ``trim_approx`` is ``True``, the first element of the list contains
|
||||
the array of approximation coefficients from the final level of
|
||||
decomposition, while the remaining coefficient dictionaries contain
|
||||
only detail coefficients. This matches the behavior of `pywt.wavedecn`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The implementation here follows the "algorithm a-trous" and requires that
|
||||
the signal length along the transformed axes be a multiple of ``2**level``.
|
||||
If this is not the case, the user should pad up to an appropriate size
|
||||
using a function such as ``numpy.pad``.
|
||||
|
||||
A primary benefit of this transform in comparison to its decimated
|
||||
counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes
|
||||
at cost of redundancy in the transform (the size of the output coefficients
|
||||
is larger than the input).
|
||||
|
||||
When the following three conditions are true:
|
||||
|
||||
1. The wavelet is orthogonal
|
||||
2. ``swtn`` is called with ``norm=True``
|
||||
3. ``swtn`` is called with ``trim_approx=True``
|
||||
|
||||
the transform has the following additional properties that may be
|
||||
desirable in applications:
|
||||
|
||||
1. energy is conserved
|
||||
2. variance is partitioned across scales
|
||||
|
||||
"""
|
||||
data = np.asarray(data)
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
kwargs = dict(wavelet=wavelet, level=level, start_level=start_level,
|
||||
trim_approx=trim_approx, axes=axes, norm=norm)
|
||||
real = swtn(data.real, **kwargs)
|
||||
imag = swtn(data.imag, **kwargs)
|
||||
if trim_approx:
|
||||
cplx = [real[0] + 1j * imag[0]]
|
||||
offset = 1
|
||||
else:
|
||||
cplx = []
|
||||
offset = 0
|
||||
for rdict, idict in zip(real[offset:], imag[offset:]):
|
||||
cplx.append(
|
||||
dict((k, rdict[k] + 1j * idict[k]) for k in rdict.keys()))
|
||||
return cplx
|
||||
|
||||
if data.dtype == np.dtype('object'):
|
||||
raise TypeError("Input must be a numeric array-like")
|
||||
if data.ndim < 1:
|
||||
raise ValueError("Input data must be at least 1D")
|
||||
|
||||
if axes is None:
|
||||
axes = range(data.ndim)
|
||||
axes = [a + data.ndim if a < 0 else a for a in axes]
|
||||
if any(a < 0 or a >= data.ndim for a in axes):
|
||||
raise np.AxisError("Axis greater than data dimensions")
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError("The axes passed to swtn must be unique.")
|
||||
num_axes = len(axes)
|
||||
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
if norm:
|
||||
if not np.all([wav.orthogonal for wav in wavelets]):
|
||||
warnings.warn(
|
||||
"norm=True, but the wavelets used are not orthogonal: \n"
|
||||
"\tThe conditions for energy preservation are not satisfied.")
|
||||
wavelets = [_rescale_wavelet_filterbank(wav, 1/np.sqrt(2))
|
||||
for wav in wavelets]
|
||||
ret = []
|
||||
for i in range(start_level, start_level + level):
|
||||
coeffs = [('', data)]
|
||||
for axis, wavelet in zip(axes, wavelets):
|
||||
new_coeffs = []
|
||||
for subband, x in coeffs:
|
||||
cA, cD = _swt_axis(x, wavelet, level=1, start_level=i,
|
||||
axis=axis)[0]
|
||||
new_coeffs.extend([(subband + 'a', cA),
|
||||
(subband + 'd', cD)])
|
||||
coeffs = new_coeffs
|
||||
|
||||
coeffs = dict(coeffs)
|
||||
ret.append(coeffs)
|
||||
|
||||
# data for the next level is the approximation coeffs from this level
|
||||
data = coeffs['a' * num_axes]
|
||||
if trim_approx:
|
||||
coeffs.pop('a' * num_axes)
|
||||
if trim_approx:
|
||||
ret.append(data)
|
||||
ret.reverse()
|
||||
return ret
|
||||
|
||||
|
||||
def iswtn(coeffs, wavelet, axes=None, norm=False):
|
||||
"""
|
||||
Multilevel nD inverse discrete stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : list
|
||||
[{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple of wavelets to apply per
|
||||
axis in ``axes``.
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the inverse SWT. Axes may not be repeated.
|
||||
The default is ``None``, which means transform all axes
|
||||
(``axes = range(data.ndim)``).
|
||||
norm : bool, optional
|
||||
Controls the normalization used by the inverse transform. This must
|
||||
be set equal to the value that was used by ``pywt.swtn`` to preserve
|
||||
the energy of a round-trip transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
nD array of reconstructed data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> coeffs = pywt.swtn([[1,2,3,4],[5,6,7,8],
|
||||
... [9,10,11,12],[13,14,15,16]],
|
||||
... 'db1', level=2)
|
||||
>>> pywt.iswtn(coeffs, 'db1')
|
||||
array([[ 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8.],
|
||||
[ 9., 10., 11., 12.],
|
||||
[ 13., 14., 15., 16.]])
|
||||
|
||||
"""
|
||||
|
||||
# key length matches the number of axes transformed
|
||||
ndim_transform = max(len(key) for key in coeffs[-1].keys())
|
||||
trim_approx = not isinstance(coeffs[0], dict)
|
||||
cA = coeffs[0] if trim_approx else coeffs[0]['a'*ndim_transform]
|
||||
|
||||
if not _have_c99_complex and np.iscomplexobj(cA):
|
||||
if trim_approx:
|
||||
coeffs_real = [coeffs[0].real]
|
||||
coeffs_imag = [coeffs[0].imag]
|
||||
coeffs = coeffs[1:]
|
||||
else:
|
||||
coeffs_real = []
|
||||
coeffs_imag = []
|
||||
coeffs_real += [{k: v.real for k, v in c.items()} for c in coeffs]
|
||||
coeffs_imag += [{k: v.imag for k, v in c.items()} for c in coeffs]
|
||||
kwargs = dict(wavelet=wavelet, axes=axes, norm=norm)
|
||||
y = iswtn(coeffs_real, **kwargs)
|
||||
return y + 1j * iswtn(coeffs_imag, **kwargs)
|
||||
|
||||
if trim_approx:
|
||||
coeffs = coeffs[1:]
|
||||
|
||||
# copy to avoid modification of input data
|
||||
dt = _check_dtype(cA)
|
||||
output = np.array(cA, dtype=dt, copy=True)
|
||||
ndim = output.ndim
|
||||
|
||||
if axes is None:
|
||||
axes = range(output.ndim)
|
||||
axes = [a + ndim if a < 0 else a for a in axes]
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError("The axes passed to swtn must be unique.")
|
||||
if ndim_transform != len(axes):
|
||||
raise ValueError("The number of axes used in iswtn must match the "
|
||||
"number of dimensions transformed in swtn.")
|
||||
|
||||
# num_levels, equivalent to the decomposition level, n
|
||||
num_levels = len(coeffs)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
if norm:
|
||||
wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2))
|
||||
for wav in wavelets]
|
||||
|
||||
# initialize various slice objects used in the loops below
|
||||
# these will remain slice(None) only on axes that aren't transformed
|
||||
indices = [slice(None), ]*ndim
|
||||
even_indices = [slice(None), ]*ndim
|
||||
odd_indices = [slice(None), ]*ndim
|
||||
odd_even_slices = [slice(None), ]*ndim
|
||||
|
||||
for j in range(num_levels):
|
||||
step_size = int(pow(2, num_levels-j-1))
|
||||
last_index = step_size
|
||||
if not trim_approx:
|
||||
a = coeffs[j].pop('a'*ndim_transform) # will restore later
|
||||
details = coeffs[j]
|
||||
# make sure dtype matches the coarsest level approximation coefficients
|
||||
common_dtype = np.result_type(*(
|
||||
[dt, ] + [v.dtype for v in details.values()]))
|
||||
if output.dtype != common_dtype:
|
||||
output = output.astype(common_dtype)
|
||||
|
||||
# We assume all coefficient arrays are of equal size
|
||||
shapes = [v.shape for k, v in details.items()]
|
||||
if len(set(shapes)) != 1:
|
||||
raise RuntimeError(
|
||||
"Mismatch in shape of intermediate coefficient arrays")
|
||||
|
||||
# shape of a single coefficient array, excluding non-transformed axes
|
||||
coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])
|
||||
|
||||
# nested loop over all combinations of axis offsets at this level
|
||||
for firsts in product(*([range(last_index), ]*ndim_transform)):
|
||||
for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
|
||||
indices[ax] = slice(first, sh, step_size)
|
||||
even_indices[ax] = slice(first, sh, 2*step_size)
|
||||
odd_indices[ax] = slice(first+step_size, sh, 2*step_size)
|
||||
|
||||
# nested loop over all combinations of odd/even inidices
|
||||
approx = output.copy()
|
||||
output[tuple(indices)] = 0
|
||||
ntransforms = 0
|
||||
for odds in product(*([(0, 1), ]*ndim_transform)):
|
||||
for o, ax in zip(odds, axes):
|
||||
if o:
|
||||
odd_even_slices[ax] = odd_indices[ax]
|
||||
else:
|
||||
odd_even_slices[ax] = even_indices[ax]
|
||||
# extract the odd/even indices for all detail coefficients
|
||||
details_slice = {}
|
||||
for key, value in details.items():
|
||||
details_slice[key] = value[tuple(odd_even_slices)]
|
||||
details_slice['a'*ndim_transform] = approx[
|
||||
tuple(odd_even_slices)]
|
||||
|
||||
# perform the inverse dwt on the selected indices,
|
||||
# making sure to use periodic boundary conditions
|
||||
x = idwtn(details_slice, wavelets, 'periodization', axes=axes)
|
||||
for o, ax in zip(odds, axes):
|
||||
# circular shift along any odd indexed axis
|
||||
if o:
|
||||
x = np.roll(x, 1, axis=ax)
|
||||
output[tuple(indices)] += x
|
||||
ntransforms += 1
|
||||
output[tuple(indices)] /= ntransforms # normalize
|
||||
if not trim_approx:
|
||||
coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict
|
||||
return output
|
||||
250
.CondaPkg/env/Lib/site-packages/pywt/_thresholding.py
vendored
Normal file
250
.CondaPkg/env/Lib/site-packages/pywt/_thresholding.py
vendored
Normal file
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
The thresholding helper module implements the most popular signal thresholding
|
||||
functions.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['threshold', 'threshold_firm']
|
||||
|
||||
|
||||
def soft(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
magnitude = np.absolute(data)
|
||||
|
||||
with np.errstate(divide='ignore'):
|
||||
# divide by zero okay as np.inf values get clipped, so ignore warning.
|
||||
thresholded = (1 - value/magnitude)
|
||||
thresholded.clip(min=0, max=None, out=thresholded)
|
||||
thresholded = data * thresholded
|
||||
|
||||
if substitute == 0:
|
||||
return thresholded
|
||||
else:
|
||||
cond = np.less(magnitude, value)
|
||||
return np.where(cond, substitute, thresholded)
|
||||
|
||||
|
||||
def nn_garrote(data, value, substitute=0):
|
||||
"""Non-negative Garrote."""
|
||||
data = np.asarray(data)
|
||||
magnitude = np.absolute(data)
|
||||
|
||||
with np.errstate(divide='ignore'):
|
||||
# divide by zero okay as np.inf values get clipped, so ignore warning.
|
||||
thresholded = (1 - value**2/magnitude**2)
|
||||
thresholded.clip(min=0, max=None, out=thresholded)
|
||||
thresholded = data * thresholded
|
||||
|
||||
if substitute == 0:
|
||||
return thresholded
|
||||
else:
|
||||
cond = np.less(magnitude, value)
|
||||
return np.where(cond, substitute, thresholded)
|
||||
|
||||
|
||||
def hard(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
cond = np.less(np.absolute(data), value)
|
||||
return np.where(cond, substitute, data)
|
||||
|
||||
|
||||
def greater(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
if np.iscomplexobj(data):
|
||||
raise ValueError("greater thresholding only supports real data")
|
||||
return np.where(np.less(data, value), substitute, data)
|
||||
|
||||
|
||||
def less(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
if np.iscomplexobj(data):
|
||||
raise ValueError("less thresholding only supports real data")
|
||||
return np.where(np.greater(data, value), substitute, data)
|
||||
|
||||
|
||||
thresholding_options = {'soft': soft,
|
||||
'hard': hard,
|
||||
'greater': greater,
|
||||
'less': less,
|
||||
'garrote': nn_garrote,
|
||||
# misspelled garrote for backwards compatibility
|
||||
'garotte': nn_garrote,
|
||||
}
|
||||
|
||||
|
||||
def threshold(data, value, mode='soft', substitute=0):
|
||||
"""
|
||||
Thresholds the input data depending on the mode argument.
|
||||
|
||||
In ``soft`` thresholding [1]_, data values with absolute value less than
|
||||
`param` are replaced with `substitute`. Data values with absolute value
|
||||
greater or equal to the thresholding value are shrunk toward zero
|
||||
by `value`. In other words, the new value is
|
||||
``data/np.abs(data) * np.maximum(np.abs(data) - value, 0)``.
|
||||
|
||||
In ``hard`` thresholding, the data values where their absolute value is
|
||||
less than the value param are replaced with `substitute`. Data values with
|
||||
absolute value greater or equal to the thresholding value stay untouched.
|
||||
|
||||
``garrote`` corresponds to the Non-negative garrote threshold [2]_, [3]_.
|
||||
It is intermediate between ``hard`` and ``soft`` thresholding. It behaves
|
||||
like soft thresholding for small data values and approaches hard
|
||||
thresholding for large data values.
|
||||
|
||||
In ``greater`` thresholding, the data is replaced with `substitute` where
|
||||
data is below the thresholding value. Greater data values pass untouched.
|
||||
|
||||
In ``less`` thresholding, the data is replaced with `substitute` where data
|
||||
is above the thresholding value. Lesser data values pass untouched.
|
||||
|
||||
Both ``hard`` and ``soft`` thresholding also support complex-valued data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
Numeric data.
|
||||
value : scalar
|
||||
Thresholding value.
|
||||
mode : {'soft', 'hard', 'garrote', 'greater', 'less'}
|
||||
Decides the type of thresholding to be applied on input data. Default
|
||||
is 'soft'.
|
||||
substitute : float, optional
|
||||
Substitute value (default: 0).
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : array
|
||||
Thresholded array.
|
||||
|
||||
See Also
|
||||
--------
|
||||
threshold_firm
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] D.L. Donoho and I.M. Johnstone. Ideal Spatial Adaptation via
|
||||
Wavelet Shrinkage. Biometrika. Vol. 81, No. 3, pp.425-455, 1994.
|
||||
DOI:10.1093/biomet/81.3.425
|
||||
.. [2] L. Breiman. Better Subset Regression Using the Nonnegative Garrote.
|
||||
Technometrics, Vol. 37, pp. 373-384, 1995.
|
||||
DOI:10.2307/1269730
|
||||
.. [3] H-Y. Gao. Wavelet Shrinkage Denoising Using the Non-Negative
|
||||
Garrote. Journal of Computational and Graphical Statistics Vol. 7,
|
||||
No. 4, pp.469-488. 1998.
|
||||
DOI:10.1080/10618600.1998.10474789
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pywt
|
||||
>>> data = np.linspace(1, 4, 7)
|
||||
>>> data
|
||||
array([ 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. ])
|
||||
>>> pywt.threshold(data, 2, 'soft')
|
||||
array([ 0. , 0. , 0. , 0.5, 1. , 1.5, 2. ])
|
||||
>>> pywt.threshold(data, 2, 'hard')
|
||||
array([ 0. , 0. , 2. , 2.5, 3. , 3.5, 4. ])
|
||||
>>> pywt.threshold(data, 2, 'garrote')
|
||||
array([ 0. , 0. , 0. , 0.9 , 1.66666667,
|
||||
2.35714286, 3. ])
|
||||
>>> pywt.threshold(data, 2, 'greater')
|
||||
array([ 0. , 0. , 2. , 2.5, 3. , 3.5, 4. ])
|
||||
>>> pywt.threshold(data, 2, 'less')
|
||||
array([ 1. , 1.5, 2. , 0. , 0. , 0. , 0. ])
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
return thresholding_options[mode](data, value, substitute)
|
||||
except KeyError:
|
||||
# Make sure error is always identical by sorting keys
|
||||
keys = ("'{0}'".format(key) for key in
|
||||
sorted(thresholding_options.keys()))
|
||||
raise ValueError("The mode parameter only takes values from: {0}."
|
||||
.format(', '.join(keys)))
|
||||
|
||||
|
||||
def threshold_firm(data, value_low, value_high):
|
||||
"""Firm threshold.
|
||||
|
||||
The approach is intermediate between soft and hard thresholding [1]_. It
|
||||
behaves the same as soft-thresholding for values below `value_low` and
|
||||
the same as hard-thresholding for values above `thresh_high`. For
|
||||
intermediate values, the thresholded value is in between that corresponding
|
||||
to soft or hard thresholding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array-like
|
||||
The data to threshold. This can be either real or complex-valued.
|
||||
value_low : float
|
||||
Any values smaller then `value_low` will be set to zero.
|
||||
value_high : float
|
||||
Any values larger than `value_high` will not be modified.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This thresholding technique is also known as semi-soft thresholding [2]_.
|
||||
|
||||
For each value, `x`, in `data`. This function computes::
|
||||
|
||||
if np.abs(x) <= value_low:
|
||||
return 0
|
||||
elif np.abs(x) > value_high:
|
||||
return x
|
||||
elif value_low < np.abs(x) and np.abs(x) <= value_high:
|
||||
return x * value_high * (1 - value_low/x)/(value_high - value_low)
|
||||
|
||||
``firm`` is a continuous function (like soft thresholding), but is
|
||||
unbiased for large values (like hard thresholding).
|
||||
|
||||
If ``value_high == value_low`` this function becomes hard-thresholding.
|
||||
If ``value_high`` is infinity, this function becomes soft-thresholding.
|
||||
|
||||
Returns
|
||||
-------
|
||||
val_new : array-like
|
||||
The values after firm thresholding at the specified thresholds.
|
||||
|
||||
See Also
|
||||
--------
|
||||
threshold
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] H.-Y. Gao and A.G. Bruce. Waveshrink with firm shrinkage.
|
||||
Statistica Sinica, Vol. 7, pp. 855-874, 1997.
|
||||
.. [2] A. Bruce and H-Y. Gao. WaveShrink: Shrinkage Functions and
|
||||
Thresholds. Proc. SPIE 2569, Wavelet Applications in Signal and
|
||||
Image Processing III, 1995.
|
||||
DOI:10.1117/12.217582
|
||||
"""
|
||||
|
||||
if value_low < 0:
|
||||
raise ValueError("value_low must be non-negative.")
|
||||
|
||||
if value_high < value_low:
|
||||
raise ValueError(
|
||||
"value_high must be greater than or equal to value_low.")
|
||||
|
||||
data = np.asarray(data)
|
||||
magnitude = np.absolute(data)
|
||||
with np.errstate(divide='ignore'):
|
||||
# divide by zero okay as np.inf values get clipped, so ignore warning.
|
||||
vdiff = value_high - value_low
|
||||
thresholded = value_high * (1 - value_low/magnitude) / vdiff
|
||||
thresholded.clip(min=0, max=None, out=thresholded)
|
||||
thresholded = data * thresholded
|
||||
|
||||
# restore hard-thresholding behavior for values > value_high
|
||||
large_vals = np.where(magnitude > value_high)
|
||||
if np.any(large_vals[0]):
|
||||
thresholded[large_vals] = data[large_vals]
|
||||
return thresholded
|
||||
93
.CondaPkg/env/Lib/site-packages/pywt/_utils.py
vendored
Normal file
93
.CondaPkg/env/Lib/site-packages/pywt/_utils.py
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2017 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
import inspect
|
||||
import numpy as np
|
||||
from collections.abc import Iterable
|
||||
|
||||
from ._extensions._pywt import (Wavelet, ContinuousWavelet,
|
||||
DiscreteContinuousWavelet, Modes)
|
||||
|
||||
|
||||
def _as_wavelet(wavelet):
|
||||
"""Convert wavelet name to a Wavelet object."""
|
||||
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
if isinstance(wavelet, ContinuousWavelet):
|
||||
raise ValueError(
|
||||
"A ContinuousWavelet object was provided, but only discrete "
|
||||
"Wavelet objects are supported by this function. A list of all "
|
||||
"supported discrete wavelets can be obtained by running:\n"
|
||||
"print(pywt.wavelist(kind='discrete'))")
|
||||
return wavelet
|
||||
|
||||
|
||||
def _wavelets_per_axis(wavelet, axes):
|
||||
"""Initialize Wavelets for each axis to be transformed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet or tuple of Wavelets
|
||||
If a single Wavelet is provided, it will used for all axes. Otherwise
|
||||
one Wavelet per axis must be provided.
|
||||
axes : list
|
||||
The tuple of axes to be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
wavelets : list of Wavelet objects
|
||||
A tuple of Wavelets equal in length to ``axes``.
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
if isinstance(wavelet, (str, Wavelet)):
|
||||
# same wavelet on all axes
|
||||
wavelets = [_as_wavelet(wavelet), ] * len(axes)
|
||||
elif isinstance(wavelet, Iterable):
|
||||
# (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
|
||||
if len(wavelet) == 1:
|
||||
wavelets = [_as_wavelet(wavelet[0]), ] * len(axes)
|
||||
else:
|
||||
if len(wavelet) != len(axes):
|
||||
raise ValueError((
|
||||
"The number of wavelets must match the number of axes "
|
||||
"to be transformed."))
|
||||
wavelets = [_as_wavelet(w) for w in wavelet]
|
||||
else:
|
||||
raise ValueError("wavelet must be a str, Wavelet or iterable")
|
||||
return wavelets
|
||||
|
||||
|
||||
def _modes_per_axis(modes, axes):
|
||||
"""Initialize mode for each axis to be transformed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
modes : str or tuple of strings
|
||||
If a single mode is provided, it will used for all axes. Otherwise
|
||||
one mode per axis must be provided.
|
||||
axes : tuple
|
||||
The tuple of axes to be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
modes : tuple of int
|
||||
A tuple of Modes equal in length to ``axes``.
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
if isinstance(modes, (int, str)):
|
||||
# same wavelet on all axes
|
||||
modes = [Modes.from_object(modes), ] * len(axes)
|
||||
elif isinstance(modes, Iterable):
|
||||
if len(modes) == 1:
|
||||
modes = [Modes.from_object(modes[0]), ] * len(axes)
|
||||
else:
|
||||
# (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
|
||||
if len(modes) != len(axes):
|
||||
raise ValueError(("The number of modes must match the number "
|
||||
"of axes to be transformed."))
|
||||
modes = [Modes.from_object(mode) for mode in modes]
|
||||
else:
|
||||
raise ValueError("modes must be a str, Mode enum or iterable")
|
||||
return modes
|
||||
1056
.CondaPkg/env/Lib/site-packages/pywt/_wavelet_packets.py
vendored
Normal file
1056
.CondaPkg/env/Lib/site-packages/pywt/_wavelet_packets.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
6
.CondaPkg/env/Lib/site-packages/pywt/conftest.py
vendored
Normal file
6
.CondaPkg/env/Lib/site-packages/pywt/conftest.py
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers",
|
||||
"slow: Tests that are slow.")
|
||||
2
.CondaPkg/env/Lib/site-packages/pywt/data/__init__.py
vendored
Normal file
2
.CondaPkg/env/Lib/site-packages/pywt/data/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
from ._readers import ascent, aero, ecg, camera, nino
|
||||
from ._wavelab_signals import demo_signal
|
||||
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/_readers.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/_readers.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/_wavelab_signals.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/_wavelab_signals.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/create_dat.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/__pycache__/create_dat.cpython-311.pyc
vendored
Normal file
Binary file not shown.
198
.CondaPkg/env/Lib/site-packages/pywt/data/_readers.py
vendored
Normal file
198
.CondaPkg/env/Lib/site-packages/pywt/data/_readers.py
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ascent():
|
||||
"""
|
||||
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
|
||||
easy use in demos
|
||||
|
||||
The image is derived from accent-to-the-top.jpg at
|
||||
http://www.public-domain-image.com/people-public-domain-images-pictures/
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
ascent : ndarray
|
||||
convenient image to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> ascent = pywt.data.ascent()
|
||||
>>> ascent.shape == (512, 512)
|
||||
True
|
||||
>>> ascent.max()
|
||||
255
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.gray()
|
||||
>>> plt.imshow(ascent) # doctest: +ELLIPSIS
|
||||
<matplotlib.image.AxesImage object at ...>
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'ascent.npz')
|
||||
ascent = np.load(fname)['data']
|
||||
return ascent
|
||||
|
||||
|
||||
def aero():
|
||||
"""
|
||||
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
|
||||
easy use in demos
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
aero : ndarray
|
||||
convenient image to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> aero = pywt.data.ascent()
|
||||
>>> aero.shape == (512, 512)
|
||||
True
|
||||
>>> aero.max()
|
||||
255
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.gray()
|
||||
>>> plt.imshow(aero) # doctest: +ELLIPSIS
|
||||
<matplotlib.image.AxesImage object at ...>
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'aero.npz')
|
||||
aero = np.load(fname)['data']
|
||||
return aero
|
||||
|
||||
|
||||
def camera():
|
||||
"""
|
||||
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
|
||||
easy use in demos
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
camera : ndarray
|
||||
convenient image to use for testing and demonstration
|
||||
|
||||
Notes
|
||||
-----
|
||||
No copyright restrictions. CC0 by the photographer (Lav Varshney).
|
||||
|
||||
.. versionchanged:: 0.18
|
||||
This image was replaced due to copyright restrictions. For more
|
||||
information, please see [1]_, where the same change was made in
|
||||
scikit-image.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] https://github.com/scikit-image/scikit-image/issues/3927
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> camera = pywt.data.ascent()
|
||||
>>> camera.shape == (512, 512)
|
||||
True
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.gray()
|
||||
>>> plt.imshow(camera) # doctest: +ELLIPSIS
|
||||
<matplotlib.image.AxesImage object at ...>
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'camera.npz')
|
||||
camera = np.load(fname)['data']
|
||||
return camera
|
||||
|
||||
|
||||
def ecg():
|
||||
"""
|
||||
Get 1024 points of an ECG timeseries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
ecg : ndarray
|
||||
convenient timeseries to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> ecg = pywt.data.ecg()
|
||||
>>> ecg.shape == (1024,)
|
||||
True
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.plot(ecg) # doctest: +ELLIPSIS
|
||||
[<matplotlib.lines.Line2D object at ...>]
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'ecg.npy')
|
||||
ecg = np.load(fname)
|
||||
return ecg
|
||||
|
||||
|
||||
def nino():
|
||||
"""
|
||||
This data contains the averaged monthly sea surface temperature in degrees
|
||||
Celsius of the Pacific Ocean, between 0-10 degrees South and 90-80 degrees West, from 1950 to 2016.
|
||||
This dataset is in the public domain and was obtained from NOAA.
|
||||
National Oceanic and Atmospheric Administration's National Weather Service
|
||||
ERSSTv4 dataset, nino 3, http://www.cpc.ncep.noaa.gov/data/indices/
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
time : ndarray
|
||||
convenient timeseries to use for testing and demonstration
|
||||
sst : ndarray
|
||||
convenient timeseries to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> time, sst = pywt.data.nino()
|
||||
>>> sst.shape == (264,)
|
||||
True
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.plot(time,sst) # doctest: +ELLIPSIS
|
||||
[<matplotlib.lines.Line2D object at ...>]
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'sst_nino3.npz')
|
||||
sst_csv = np.load(fname)['sst_csv']
|
||||
# sst_csv = pd.read_csv("http://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii", sep=' ', skipinitialspace=True)
|
||||
# take only full years
|
||||
n = int(np.floor(sst_csv.shape[0]/12.)*12.)
|
||||
# Building the mean of three months
|
||||
# the 4. column is nino 3
|
||||
sst = np.mean(np.reshape(np.array(sst_csv)[:n, 4], (n//3, -1)), axis=1)
|
||||
sst = (sst - np.mean(sst)) / np.std(sst, ddof=1)
|
||||
|
||||
dt = 0.25
|
||||
time = np.arange(len(sst)) * dt + 1950.0 # construct time array
|
||||
return time, sst
|
||||
259
.CondaPkg/env/Lib/site-packages/pywt/data/_wavelab_signals.py
vendored
Normal file
259
.CondaPkg/env/Lib/site-packages/pywt/data/_wavelab_signals.py
vendored
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['demo_signal']
|
||||
|
||||
_implemented_signals = [
|
||||
'Blocks',
|
||||
'Bumps',
|
||||
'HeaviSine',
|
||||
'Doppler',
|
||||
'Ramp',
|
||||
'HiSine',
|
||||
'LoSine',
|
||||
'LinChirp',
|
||||
'TwoChirp',
|
||||
'QuadChirp',
|
||||
'MishMash',
|
||||
'WernerSorrows',
|
||||
'HypChirps',
|
||||
'LinChirps',
|
||||
'Chirps',
|
||||
'Gabor',
|
||||
'sineoneoverx',
|
||||
'Piece-Regular',
|
||||
'Piece-Polynomial',
|
||||
'Riemann']
|
||||
|
||||
|
||||
def demo_signal(name='Bumps', n=None):
|
||||
"""Simple 1D wavelet test functions.
|
||||
|
||||
This function can generate a number of common 1D test signals used in
|
||||
papers by David Donoho and colleagues (e.g. [1]_) as well as the wavelet
|
||||
book by Stéphane Mallat [2]_.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : {'Blocks', 'Bumps', 'HeaviSine', 'Doppler', ...}
|
||||
The type of test signal to generate (`name` is case-insensitive). If
|
||||
`name` is set to `'list'`, a list of the available test functions is
|
||||
returned.
|
||||
n : int or None
|
||||
The length of the test signal. This should be provided for all test
|
||||
signals except `'Gabor'` and `'sineoneoverx'` which have a fixed
|
||||
length.
|
||||
|
||||
Returns
|
||||
-------
|
||||
f : np.ndarray
|
||||
Array of length ``n`` corresponding to the specified test signal type.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] D.L. Donoho and I.M. Johnstone. Ideal spatial adaptation by
|
||||
wavelet shrinkage. Biometrika, vol. 81, pp. 425–455, 1994.
|
||||
.. [2] S. Mallat. A Wavelet Tour of Signal Processing: The Sparse Way.
|
||||
Academic Press. 2009.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is a partial reimplementation of the `MakeSignal` function
|
||||
from the [Wavelab](https://statweb.stanford.edu/~wavelab/) toolbox. These
|
||||
test signals are provided with permission of Dr. Donoho to encourage
|
||||
reproducible research.
|
||||
|
||||
"""
|
||||
if name.lower() == 'list':
|
||||
return _implemented_signals
|
||||
|
||||
if n is not None:
|
||||
if n < 1 or (n % 1) != 0:
|
||||
raise ValueError("n must be an integer >= 1")
|
||||
t = np.arange(1/n, 1 + 1/n, 1/n)
|
||||
|
||||
# The following function types don't allow user-specified `n`.
|
||||
n_hard_coded = ['gabor', 'sineoneoverx']
|
||||
|
||||
name = name.lower()
|
||||
if name in n_hard_coded and n is not None:
|
||||
raise ValueError(
|
||||
"Parameter n must be set to None when name is {}".format(name))
|
||||
elif n is None and name not in n_hard_coded:
|
||||
raise ValueError(
|
||||
"Parameter n must be provided when name is {}".format(name))
|
||||
|
||||
if name == 'blocks':
|
||||
t0s = [.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]
|
||||
hs = [4, -5, 3, -4, 5, -4.2, 2.1, 4.3, -3.1, 2.1, -4.2]
|
||||
f = 0
|
||||
for (t0, h) in zip(t0s, hs):
|
||||
f += h * (1 + np.sign(t - t0)) / 2
|
||||
elif name == 'bumps':
|
||||
t0s = [.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]
|
||||
hs = [4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]
|
||||
ws = [.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]
|
||||
f = 0
|
||||
for (t0, h, w) in zip(t0s, hs, ws):
|
||||
f += h / (1 + np.abs((t - t0) / w))**4
|
||||
elif name == 'heavisine':
|
||||
f = 4 * np.sin(4 * np.pi * t) - np.sign(t - 0.3) - np.sign(0.72 - t)
|
||||
elif name == 'doppler':
|
||||
f = np.sqrt(t * (1 - t)) * np.sin(2 * np.pi * 1.05 / (t + 0.05))
|
||||
elif name == 'ramp':
|
||||
f = t - (t >= .37)
|
||||
elif name == 'hisine':
|
||||
f = np.sin(np.pi * (n * .6902) * t)
|
||||
elif name == 'losine':
|
||||
f = np.sin(np.pi * (n * .3333) * t)
|
||||
elif name == 'linchirp':
|
||||
f = np.sin(np.pi * t * ((n * .500) * t))
|
||||
elif name == 'twochirp':
|
||||
f = np.sin(np.pi * t * (n * t)) + np.sin((np.pi / 3) * t * (n * t))
|
||||
elif name == 'quadchirp':
|
||||
f = np.sin((np.pi / 3) * t * (n * t**2))
|
||||
elif name == 'mishmash': # QuadChirp + LinChirp + HiSine
|
||||
f = np.sin((np.pi / 3) * t * (n * t**2))
|
||||
f += np.sin(np.pi * (n * .6902) * t)
|
||||
f += np.sin(np.pi * t * (n * .125 * t))
|
||||
elif name == 'wernersorrows':
|
||||
f = np.sin(np.pi * t * (n / 2 * t**2))
|
||||
f = f + np.sin(np.pi * (n * .6902) * t)
|
||||
f = f + np.sin(np.pi * t * (n * t))
|
||||
pos = [.1, .13, .15, .23, .25, .40, .44, .65, .76, .78, .81]
|
||||
hgt = [4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]
|
||||
wth = [.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]
|
||||
for p, h, w in zip(pos, hgt, wth):
|
||||
f += h / (1 + np.abs((t - p) / w))**4
|
||||
elif name == 'hypchirps': # Hyperbolic Chirps of Mallat's book
|
||||
alpha = 15 * n * np.pi / 1024
|
||||
beta = 5 * n * np.pi / 1024
|
||||
t = np.arange(1.001, n + .001 + 1) / n
|
||||
f1 = np.zeros(n)
|
||||
f2 = np.zeros(n)
|
||||
f1 = np.sin(alpha / (.8 - t)) * (0.1 < t) * (t < 0.68)
|
||||
f2 = np.sin(beta / (.8 - t)) * (0.1 < t) * (t < 0.75)
|
||||
m = int(np.round(0.65 * n))
|
||||
p = m // 4
|
||||
envelope = np.ones(m) # the rinp.sing cutoff function
|
||||
tmp = np.arange(1, p + 1)-np.ones(p)
|
||||
envelope[:p] = (1 + np.sin(-np.pi / 2 + tmp / (p - 1) * np.pi)) / 2
|
||||
envelope[m-p:m] = envelope[:p][::-1]
|
||||
env = np.zeros(n)
|
||||
env[int(np.ceil(n / 10)) - 1:m + int(np.ceil(n / 10)) - 1] = \
|
||||
envelope[:m]
|
||||
f = (f1 + f2) * env
|
||||
elif name == 'linchirps': # Linear Chirps of Mallat's book
|
||||
b = 100 * n * np.pi / 1024
|
||||
a = 250 * n * np.pi / 1024
|
||||
t = np.arange(1, n + 1) / n
|
||||
A1 = np.sqrt((t - 1 / n) * (1 - t))
|
||||
f = A1 * (np.cos(a * t**2) + np.cos(b * t + a * t**2))
|
||||
elif name == 'chirps': # Mixture of Chirps of Mallat's book
|
||||
t = np.arange(1, n + 1)/n * 10 * np.pi
|
||||
f1 = np.cos(t**2 * n / 1024)
|
||||
a = 30 * n / 1024
|
||||
t = np.arange(1, n + 1)/n * np.pi
|
||||
f2 = np.cos(a * (t**3))
|
||||
f2 = f2[::-1]
|
||||
ix = np.arange(-n, n + 1) / n * 20
|
||||
g = np.exp(-ix**2 * 4 * n / 1024)
|
||||
i1 = slice(n // 2, n // 2 + n)
|
||||
i2 = slice(n // 8, n // 8 + n)
|
||||
j = np.arange(1, n + 1) / n
|
||||
f3 = g[i1] * np.cos(50 * np.pi * j * n / 1024)
|
||||
f4 = g[i2] * np.cos(350 * np.pi * j * n / 1024)
|
||||
f = f1 + f2 + f3 + f4
|
||||
envelope = np.ones(n) # the rinp.sing cutoff function
|
||||
tmp = np.arange(1, n // 8 + 1) - np.ones(n // 8)
|
||||
envelope[:n // 8] = (
|
||||
1 + np.sin(-np.pi / 2 + tmp / (n / 8 - 1) * np.pi)) / 2
|
||||
envelope[7 * n // 8:n] = envelope[:n // 8][::-1]
|
||||
f = f*envelope
|
||||
elif name == 'gabor': # two modulated Gabor functions in Mallat's book
|
||||
n = 512
|
||||
t = np.arange(-n, n + 1)*5 / n
|
||||
j = np.arange(1, n + 1) / n
|
||||
g = np.exp(-t**2 * 20)
|
||||
i1 = slice(2*n // 4, 2 * n // 4 + n)
|
||||
i2 = slice(n // 4, n // 4 + n)
|
||||
f1 = 3 * g[i1] * np.exp(1j * (n // 16) * np.pi * j)
|
||||
f2 = 3 * g[i2] * np.exp(1j * (n // 4) * np.pi * j)
|
||||
f = f1 + f2
|
||||
elif name == 'sineoneoverx': # np.sin(1/x) in Mallat's book
|
||||
n = 1024
|
||||
i1 = np.arange(-n + 1, n + 1, dtype=float)
|
||||
i1[i1 == 0] = 1 / 100
|
||||
i1 = i1 / (n - 1)
|
||||
f = np.sin(1.5 / i1)
|
||||
f = f[512:1536]
|
||||
elif name == 'piece-regular':
|
||||
f = np.zeros(n)
|
||||
n_12 = int(np.fix(n / 12))
|
||||
n_7 = int(np.fix(n / 7))
|
||||
n_5 = int(np.fix(n / 5))
|
||||
n_3 = int(np.fix(n / 3))
|
||||
n_2 = int(np.fix(n / 2))
|
||||
n_20 = int(np.fix(n / 20))
|
||||
f1 = -15 * demo_signal('bumps', n)
|
||||
t = np.arange(1, n_12 + 1) / n_12
|
||||
f2 = -np.exp(4 * t)
|
||||
t = np.arange(1, n_7 + 1) / n_7
|
||||
f5 = np.exp(4 * t)-np.exp(4)
|
||||
t = np.arange(1, n_3 + 1) / n_3
|
||||
fma = 6 / 40
|
||||
f6 = -70 * np.exp(-((t - 0.5) * (t - 0.5)) / (2 * fma**2))
|
||||
f[:n_7] = f6[:n_7]
|
||||
f[n_7:n_5] = 0.5 * f6[n_7:n_5]
|
||||
f[n_5:n_3] = f6[n_5:n_3]
|
||||
f[n_3:n_2] = f1[n_3:n_2]
|
||||
f[n_2:n_2 + n_12] = f2
|
||||
f[n_2 + 2 * n_12 - 1:n_2 + n_12 - 1:-1] = f2
|
||||
f[n_2 + 2 * n_12 + n_20:n_2 + 2 * n_12 + 3 * n_20] = -np.ones(
|
||||
n_2 + 2*n_12 + 3*n_20 - n_2 - 2*n_12 - n_20) * 25
|
||||
k = n_2 + 2 * n_12 + 3 * n_20
|
||||
f[k:k + n_7] = f5
|
||||
diff = n - 5 * n_5
|
||||
f[5 * n_5:n] = f[diff - 1::-1]
|
||||
# zero-mean
|
||||
bias = np.sum(f) / n
|
||||
f = bias - f
|
||||
elif name == 'piece-polynomial':
|
||||
f = np.zeros(n)
|
||||
n_5 = int(np.fix(n / 5))
|
||||
n_10 = int(np.fix(n / 10))
|
||||
n_20 = int(np.fix(n / 20))
|
||||
t = np.arange(1, n_5 + 1) / n_5
|
||||
f1 = 20 * (t**3 + t**2 + 4)
|
||||
f3 = 40 * (2 * t**3 + t) + 100
|
||||
f2 = 10 * t**3 + 45
|
||||
f4 = 16 * t**2 + 8 * t + 16
|
||||
f5 = 20 * (t + 4)
|
||||
f6 = np.ones(n_10) * 20
|
||||
f[:n_5] = f1
|
||||
f[2 * n_5 - 1:n_5 - 1:-1] = f2
|
||||
f[2 * n_5:3 * n_5] = f3
|
||||
f[3 * n_5:4 * n_5] = f4
|
||||
f[4 * n_5:5 * n_5] = f5[n_5::-1]
|
||||
diff = n - 5*n_5
|
||||
f[5 * n_5:n] = f[diff - 1::-1]
|
||||
f[n_20:n_20 + n_10] = np.ones(n_10) * 10
|
||||
f[n - n_10:n + n_20 - n_10] = np.ones(n_20) * 150
|
||||
# zero-mean
|
||||
bias = np.sum(f) / n
|
||||
f = f - bias
|
||||
elif name == 'riemann':
|
||||
# Riemann's Non-differentiable Function
|
||||
sqn = int(np.round(np.sqrt(n)))
|
||||
idx = np.arange(1, sqn + 1)
|
||||
idx *= idx
|
||||
f = np.zeros_like(t)
|
||||
f[idx - 1] = 1. / np.arange(1, sqn + 1)
|
||||
f = np.real(np.fft.ifft(f))
|
||||
else:
|
||||
raise ValueError(
|
||||
"unknown name: {}. name must be one of: {}".format(
|
||||
name, _implemented_signals))
|
||||
return f
|
||||
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/aero.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/aero.npz
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/ascent.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/ascent.npz
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/camera.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/camera.npz
vendored
Normal file
Binary file not shown.
39
.CondaPkg/env/Lib/site-packages/pywt/data/create_dat.py
vendored
Normal file
39
.CondaPkg/env/Lib/site-packages/pywt/data/create_dat.py
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Helper script for creating image .dat files by numpy.save
|
||||
|
||||
Usage:
|
||||
|
||||
python create_dat.py <name of image file> <name of dat file>
|
||||
|
||||
Example (to create aero.dat):
|
||||
|
||||
python create_dat.py aero.png aero.dat
|
||||
|
||||
Requires Scipy and PIL.
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
from scipy.misc import imread
|
||||
|
||||
if len(sys.argv) != 3:
|
||||
print(__doc__)
|
||||
exit()
|
||||
|
||||
image_fname = sys.argv[1]
|
||||
dat_fname = sys.argv[2]
|
||||
|
||||
data = imread(image_fname)
|
||||
|
||||
np.savez_compressed(dat_fname, data=data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/ecg.npy
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/ecg.npy
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/sst_nino3.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/data/sst_nino3.npz
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test__pywt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test__pywt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_concurrent.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_concurrent.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_cwt_wavelets.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_cwt_wavelets.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_data.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_data.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_deprecations.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_deprecations.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_doc.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_doc.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_dwt_idwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_dwt_idwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_functions.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_functions.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility_cwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility_cwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_modes.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_modes.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_mra.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_mra.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multidim.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multidim.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multilevel.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multilevel.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_perfect_reconstruction.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_perfect_reconstruction.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_swt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_swt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_thresholding.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_thresholding.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wavelet.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wavelet.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp2d.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp2d.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wpnd.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wpnd.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data_cwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data_cwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/cwt_matlabR2015b_result.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/cwt_matlabR2015b_result.npz
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/dwt_matlabR2012a_result.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/dwt_matlabR2012a_result.npz
vendored
Normal file
Binary file not shown.
96
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data.py
vendored
Normal file
96
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data.py
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw')]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
# Matlab result
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = ("[ma, md] = dwt(data, wavelet, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
|
||||
# Matlab result
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('dwt_matlabR2012a_result.npz', **all_matlab_results)
|
||||
86
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data_cwt.py
vendored
Normal file
86
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data_cwt.py
vendored
Normal file
@@ -0,0 +1,86 @@
|
||||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
Scales = (1,np.arange(1,3))
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
all_matlab_results[psi_key] = psi
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
|
||||
# Matlab result
|
||||
scale_count = 0
|
||||
for scales in Scales:
|
||||
scale_count += 1
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
coefs_key = '_'.join([str(scale_count), wavelet, str(N), 'coefs'])
|
||||
all_matlab_results[coefs_key] = coefs
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('cwt_matlabR2015b_result.npz', **all_matlab_results)
|
||||
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
vendored
Normal file
Binary file not shown.
170
.CondaPkg/env/Lib/site-packages/pywt/tests/test__pywt.py
vendored
Normal file
170
.CondaPkg/env/Lib/site-packages/pywt/tests/test__pywt.py
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_upcoef_reconstruct():
|
||||
data = np.arange(3)
|
||||
a = pywt.downcoef('a', data, 'haar')
|
||||
d = pywt.downcoef('d', data, 'haar')
|
||||
|
||||
rec = (pywt.upcoef('a', a, 'haar', take=3) +
|
||||
pywt.upcoef('d', d, 'haar', take=3))
|
||||
assert_allclose(rec, data)
|
||||
|
||||
|
||||
def test_downcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.downcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_downcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16) + 1j * rstate.randn(16)
|
||||
nlevels = 3
|
||||
a = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_downcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar')
|
||||
|
||||
|
||||
def test_compare_downcoef_coeffs():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
# compare downcoef against wavedec outputs
|
||||
for nlevels in [1, 2, 3]:
|
||||
for wavelet in pywt.wavelist():
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
||||
if isinstance(wavelet, pywt.Wavelet):
|
||||
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
|
||||
if nlevels <= max_level:
|
||||
a = pywt.downcoef('a', r, wavelet, level=nlevels)
|
||||
d = pywt.downcoef('d', r, wavelet, level=nlevels)
|
||||
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
|
||||
assert_allclose(a, coeffs[0])
|
||||
assert_allclose(d, coeffs[1])
|
||||
|
||||
|
||||
def test_upcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.upcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_upcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4) + 1j*rstate.randn(4)
|
||||
nlevels = 3
|
||||
a = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_upcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar')
|
||||
|
||||
|
||||
def test_upcoef_and_downcoef_1d_only():
|
||||
# upcoef and downcoef raise a ValueError if data.ndim > 1d
|
||||
for ndim in [2, 3]:
|
||||
data = np.ones((8, )*ndim)
|
||||
assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar')
|
||||
assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar')
|
||||
|
||||
|
||||
def test_wavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.Wavelet('sym8')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_dwt_max_level():
|
||||
assert_(pywt.dwt_max_level(16, 2) == 4)
|
||||
assert_(pywt.dwt_max_level(16, 8) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 9) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 10) == 0)
|
||||
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 10.) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 18) == 0)
|
||||
|
||||
# accepts discrete Wavelet object or string as well
|
||||
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
|
||||
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
|
||||
|
||||
# string input that is not a discrete wavelet
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
|
||||
|
||||
# filter_len must be an integer >= 2
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
|
||||
|
||||
|
||||
def test_ContinuousWavelet_errs():
|
||||
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
|
||||
|
||||
|
||||
def test_ContinuousWavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.ContinuousWavelet('gaus2')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_wavelist():
|
||||
for name in pywt.wavelist(family='coif'):
|
||||
assert_(name.startswith('coif'))
|
||||
|
||||
assert_('cgau7' in pywt.wavelist(kind='continuous'))
|
||||
assert_('sym20' in pywt.wavelist(kind='discrete'))
|
||||
assert_(len(pywt.wavelist(kind='continuous')) +
|
||||
len(pywt.wavelist(kind='discrete')) ==
|
||||
len(pywt.wavelist(kind='all')))
|
||||
|
||||
assert_raises(ValueError, pywt.wavelist, kind='foobar')
|
||||
|
||||
|
||||
def test_wavelet_errormsgs():
|
||||
try:
|
||||
pywt.Wavelet('gaus1')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0].startswith('The `Wavelet` class'))
|
||||
try:
|
||||
pywt.Wavelet('cmord')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")
|
||||
105
.CondaPkg/env/Lib/site-packages/pywt/tests/test_concurrent.py
vendored
Normal file
105
.CondaPkg/env/Lib/site-packages/pywt/tests/test_concurrent.py
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Tests used to verify running PyWavelets transforms in parallel via
|
||||
concurrent.futures.ThreadPoolExecutor does not raise errors.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from numpy.testing import assert_array_equal, assert_allclose
|
||||
from pywt._pytest import uses_futures, futures, max_workers
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def _assert_all_coeffs_equal(coefs1, coefs2):
|
||||
# return True only if all coefficients of SWT or DWT match over all levels
|
||||
if len(coefs1) != len(coefs2):
|
||||
return False
|
||||
for (c1, c2) in zip(coefs1, coefs2):
|
||||
if isinstance(c1, tuple):
|
||||
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
|
||||
for a1, a2 in zip(c1, c2):
|
||||
assert_array_equal(a1, a2)
|
||||
elif isinstance(c1, dict):
|
||||
# for swtn, dwtn, wavedecn
|
||||
for k, v in c1.items():
|
||||
assert_array_equal(v, c2[k])
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_swt():
|
||||
# tests error-free concurrent operation (see gh-288)
|
||||
# swt on 1D data calls the Cython swt
|
||||
# other cases call swt_axes
|
||||
with warnings.catch_warnings():
|
||||
# can remove catch_warnings once the swt2 FutureWarning is removed
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(swt_func, wavelet='haar', level=3)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_wavedec():
|
||||
# wavedec on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(wavedec_func, wavelet='haar', level=1)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_dwt():
|
||||
# dwt on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(dwt_func, wavelet='haar')
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_cwt():
|
||||
atol = rtol = 1e-14
|
||||
time, sst = pywt.data.nino()
|
||||
dt = time[1]-time[0]
|
||||
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
|
||||
sampling_period=dt)
|
||||
for _ in range(10):
|
||||
arrs = [sst.copy() for _ in range(50)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(sst)
|
||||
for a1, a2 in zip(expected_result, results[-1]):
|
||||
assert_allclose(a1, a2, atol=atol, rtol=rtol)
|
||||
462
.CondaPkg/env/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
vendored
Normal file
462
.CondaPkg/env/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
from itertools import product
|
||||
import pickle
|
||||
|
||||
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
|
||||
assert_raises, assert_equal)
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
|
||||
def ref_gaus(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
|
||||
if (num == 1):
|
||||
psi = -2.*X*F0
|
||||
elif (num == 2):
|
||||
psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
|
||||
elif (num == 3):
|
||||
psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
|
||||
elif (num == 4):
|
||||
psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
|
||||
elif (num == 5):
|
||||
psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
|
||||
elif (num == 6):
|
||||
psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
|
||||
elif (num == 7):
|
||||
psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
|
||||
elif (num == 8):
|
||||
psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
|
||||
224*X**6 + 16*X**8)*F0
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def ref_cgau(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = np.exp(-X**2)
|
||||
F1 = np.exp(-1j*X)
|
||||
F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
|
||||
if (num == 1):
|
||||
psi = F2*(-1j - 2*X)*2**(1/2)
|
||||
elif (num == 2):
|
||||
psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
|
||||
elif (num == 3):
|
||||
psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
|
||||
elif (num == 4):
|
||||
psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
|
||||
elif (num == 5):
|
||||
psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
|
||||
80j*X**4 - 32*X**5)*210**(1/2)
|
||||
elif (num == 6):
|
||||
psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
|
||||
192j*X**5 + 64*X**6)*2310**(1/2)
|
||||
elif (num == 7):
|
||||
psi = 1/45045*F2*(
|
||||
1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
|
||||
448j*X**6 - 128*X**7)*30030**(1/2)
|
||||
elif (num == 8):
|
||||
psi = 1/45045*F2*(
|
||||
5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
|
||||
12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
|
||||
|
||||
psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def sinc2(x):
|
||||
y = np.ones_like(x)
|
||||
k = np.where(x)[0]
|
||||
y[k] = np.sin(np.pi*x[k])/(np.pi*x[k])
|
||||
return y
|
||||
|
||||
|
||||
def ref_shan(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*(sinc2(Fb*x)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_fbsp(LB, UB, N, m, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*((sinc2(Fb*x/m)**m)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_cmor(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = ((np.pi*Fb)**(-0.5))*np.exp(2j*np.pi*Fc*x)*np.exp(-(x**2)/Fb)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_morl(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.exp(-(x**2)/2)*np.cos(5*x)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_mexh(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = (2/(np.sqrt(3)*np.pi**0.25))*np.exp(-(x**2)/2)*(1 - (x**2))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def test_gaus():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_gaus(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("gaus" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_continuous_wavelet_dtype(dtype):
|
||||
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0', dtype)
|
||||
int_psi, x = pywt.integrate_wavelet(wavelet)
|
||||
assert int_psi.real.dtype == dtype
|
||||
assert x.dtype == dtype
|
||||
|
||||
|
||||
def test_continuous_wavelet_invalid_dtype():
|
||||
with pytest.raises(ValueError):
|
||||
pywt.ContinuousWavelet('gaus5', np.complex64)
|
||||
with pytest.raises(ValueError):
|
||||
pywt.ContinuousWavelet('gaus5', np.int_)
|
||||
|
||||
|
||||
def test_cgau():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_cgau(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("cgau" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_shan():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_cmor():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_fbsp():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 3
|
||||
Fb = 1.5
|
||||
Fc = 1.2
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
# TODO: investigate why atol = 1e-5 is necessary
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-5)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-5)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_morl():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_morl(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("morl")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_mexh():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1001
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_cwt_parameters_in_names():
|
||||
|
||||
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
|
||||
for name in ['fbsp', 'cmor', 'shan']:
|
||||
# additional parameters should be specified within the name
|
||||
assert_warns(FutureWarning, func, name)
|
||||
|
||||
for name in ['cmor', 'shan']:
|
||||
# valid names
|
||||
func(name + '1.5-1.0')
|
||||
func(name + '1-4')
|
||||
|
||||
# invalid names
|
||||
assert_raises(ValueError, func, name + '1.0')
|
||||
assert_raises(ValueError, func, name + 'B-C')
|
||||
assert_raises(ValueError, func, name + '1.0-1.0-1.0')
|
||||
|
||||
# valid names
|
||||
func('fbsp1-1.5-1.0')
|
||||
func('fbsp1.0-1.5-1')
|
||||
func('fbsp2-5-1')
|
||||
|
||||
# invalid name (non-integer order)
|
||||
assert_raises(ValueError, func, 'fbsp1.5-1-1')
|
||||
assert_raises(ValueError, func, 'fbspM-B-C')
|
||||
|
||||
# invalid name (too few or too many params)
|
||||
assert_raises(ValueError, func, 'fbsp1.0')
|
||||
assert_raises(ValueError, func, 'fbsp1.0-0.4')
|
||||
assert_raises(ValueError, func, 'fbsp1-1-1-1')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype, tol, method',
|
||||
[(np.float32, 1e-5, 'conv'),
|
||||
(np.float32, 1e-5, 'fft'),
|
||||
(np.float64, 1e-13, 'conv'),
|
||||
(np.float64, 1e-13, 'fft')])
|
||||
def test_cwt_complex(dtype, tol, method):
|
||||
time, sst = pywt.data.nino()
|
||||
sst = np.asarray(sst, dtype=dtype)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# real-valued tranfsorm as a reference
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# complex-valued transform equals sum of the transforms of the real
|
||||
# and imaginary components
|
||||
sst_complex = sst + 1j*sst
|
||||
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
|
||||
method=method)
|
||||
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
|
||||
# verify dtype is preserved
|
||||
assert_equal(cfs_complex.dtype, sst_complex.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
|
||||
def test_cwt_batch(axis, method):
|
||||
dtype = np.float64
|
||||
time, sst = pywt.data.nino()
|
||||
n_batch = 8
|
||||
batch_axis = 1 - axis
|
||||
sst1 = np.asarray(sst, dtype=dtype)
|
||||
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# non-batch transform as reference
|
||||
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
shape_in = sst.shape
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
# shape of input is not modified
|
||||
assert_equal(shape_in, sst.shape)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# verify expected shape
|
||||
assert_equal(cfs.shape[0], len(scales))
|
||||
assert_equal(cfs.shape[1 + batch_axis], n_batch)
|
||||
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
|
||||
|
||||
# batch result on stacked input is the same as stacked 1d result
|
||||
assert_almost_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1),
|
||||
decimal=12)
|
||||
|
||||
|
||||
def test_cwt_small_scales():
|
||||
data = np.zeros(32)
|
||||
|
||||
# A scale of 0.1 was chosen specifically to give a filter of length 2 for
|
||||
# mexh. This corner case should not raise an error.
|
||||
cfs, f = pywt.cwt(data, scales=0.1, wavelet='mexh')
|
||||
assert_allclose(cfs, np.zeros_like(cfs))
|
||||
|
||||
# extremely short scale factors raise a ValueError
|
||||
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
|
||||
|
||||
|
||||
def test_cwt_method_fft():
|
||||
rstate = np.random.RandomState(1)
|
||||
data = rstate.randn(50)
|
||||
data[15] = 1.
|
||||
scales = np.arange(1, 64)
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
|
||||
# build a reference cwt with the legacy np.conv() method
|
||||
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
|
||||
|
||||
# compare with the fft based convolution
|
||||
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
|
||||
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)
|
||||
|
||||
|
||||
def test_continuous_wavelet_pickle(tmpdir):
|
||||
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0')
|
||||
filename = os.path.join(tmpdir, 'cwav.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(wavelet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
wavelet2 = pickle.load(f)
|
||||
assert isinstance(wavelet2, pywt.ContinuousWavelet)
|
||||
assert wavelet2.name == wavelet.name
|
||||
77
.CondaPkg/env/Lib/site-packages/pywt/tests/test_data.py
vendored
Normal file
77
.CondaPkg/env/Lib/site-packages/pywt/tests/test_data.py
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_
|
||||
|
||||
import pywt.data
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
||||
wavelab_data_file = os.path.join(data_dir, 'wavelab_test_signals.npz')
|
||||
wavelab_result_dict = np.load(wavelab_data_file)
|
||||
|
||||
|
||||
def test_data_aero():
|
||||
aero = pywt.data.aero()
|
||||
|
||||
ref = np.array([[178, 178, 179],
|
||||
[170, 173, 171],
|
||||
[185, 174, 171]])
|
||||
|
||||
assert_allclose(aero[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ascent():
|
||||
ascent = pywt.data.ascent()
|
||||
|
||||
ref = np.array([[83, 83, 83],
|
||||
[82, 82, 83],
|
||||
[80, 81, 83]])
|
||||
|
||||
assert_allclose(ascent[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_camera():
|
||||
camera = pywt.data.camera()
|
||||
|
||||
ref = np.array([[200, 200, 200],
|
||||
[200, 199, 199],
|
||||
[199, 199, 199]])
|
||||
|
||||
assert_allclose(camera[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ecg():
|
||||
ecg = pywt.data.ecg()
|
||||
|
||||
ref = np.array([-86, -87, -87])
|
||||
|
||||
assert_allclose(ecg[:3], ref)
|
||||
|
||||
|
||||
def test_wavelab_signals():
|
||||
"""Comparison with results generated using WaveLab"""
|
||||
rtol = atol = 1e-12
|
||||
|
||||
# get a list of the available signals
|
||||
available_signals = pywt.data.demo_signal('list')
|
||||
assert_('Doppler' in available_signals)
|
||||
|
||||
for signal in available_signals:
|
||||
# reference dictionary has lowercase names for the keys
|
||||
key = signal.replace('-', '_').lower()
|
||||
val = wavelab_result_dict[key]
|
||||
if key in ['gabor', 'sineoneoverx']:
|
||||
# these functions do not allow a size to be provided
|
||||
assert_allclose(val, pywt.data.demo_signal(signal),
|
||||
rtol=rtol, atol=atol)
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key, val.size)
|
||||
else:
|
||||
assert_allclose(val, pywt.data.demo_signal(signal, val.size),
|
||||
rtol=rtol, atol=atol)
|
||||
# these functions require a size to be provided
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key)
|
||||
|
||||
# ValueError on unrecognized signal type
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'unknown_signal', 512)
|
||||
|
||||
# ValueError on invalid length
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)
|
||||
89
.CondaPkg/env/Lib/site-packages/pywt/tests/test_deprecations.py
vendored
Normal file
89
.CondaPkg/env/Lib/site-packages/pywt/tests/test_deprecations.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_warns, assert_array_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_intwave_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
|
||||
|
||||
|
||||
def test_centrfrq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
|
||||
|
||||
|
||||
def test_scal2frq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
|
||||
|
||||
|
||||
def test_orthfilt_deprecation():
|
||||
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
|
||||
|
||||
|
||||
def test_integrate_wave_tuple():
|
||||
sig = [0, 1, 2, 3]
|
||||
xgrid = [0, 1, 2, 3]
|
||||
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
|
||||
|
||||
|
||||
old_modes = ['zpd',
|
||||
'cpd',
|
||||
'sym',
|
||||
'ppd',
|
||||
'sp1',
|
||||
'per',
|
||||
]
|
||||
|
||||
|
||||
def test_MODES_from_object_deprecation():
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
|
||||
|
||||
|
||||
def test_MODES_attributes_deprecation():
|
||||
def get_mode(Modes, name):
|
||||
return getattr(Modes, name)
|
||||
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
|
||||
|
||||
|
||||
def test_MODES_deprecation_new():
|
||||
def use_MODES_new():
|
||||
return pywt.MODES.symmetric
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_MODES_deprecation_old():
|
||||
def use_MODES_old():
|
||||
return pywt.MODES.sym
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_old)
|
||||
|
||||
|
||||
def test_MODES_deprecation_getattr():
|
||||
def use_MODES_new():
|
||||
return getattr(pywt.MODES, 'symmetric')
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_mode_equivalence():
|
||||
old_new = [('zpd', 'zero'),
|
||||
('cpd', 'constant'),
|
||||
('sym', 'symmetric'),
|
||||
('ppd', 'periodic'),
|
||||
('sp1', 'smooth'),
|
||||
('per', 'periodization')]
|
||||
x = np.arange(8.)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
for old, new in old_new:
|
||||
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
|
||||
pywt.dwt(x, 'db2', mode=new))
|
||||
25
.CondaPkg/env/Lib/site-packages/pywt/tests/test_doc.py
vendored
Normal file
25
.CondaPkg/env/Lib/site-packages/pywt/tests/test_doc.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import doctest
|
||||
import glob
|
||||
import os
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
np.set_printoptions(legacy='1.13')
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
pdir = os.path.pardir
|
||||
docs_base = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
pdir, pdir, "doc", "source"))
|
||||
|
||||
files = glob.glob(os.path.join(docs_base, "*.rst")) + \
|
||||
glob.glob(os.path.join(docs_base, "*", "*.rst"))
|
||||
|
||||
suite = doctest.DocFileSuite(*files, module_relative=False, encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.TextTestRunner().run(suite)
|
||||
307
.CondaPkg/env/Lib/site-packages/pywt/tests/test_dwt_idwt.py
vendored
Normal file
307
.CondaPkg/env/Lib/site-packages/pywt/tests/test_dwt_idwt.py
vendored
Normal file
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_array_equal)
|
||||
import pywt
|
||||
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwt_idwt_basic():
|
||||
x = [3, 7, 1, 1, -2, 5, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
|
||||
cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.float64)
|
||||
|
||||
|
||||
def test_idwt_mixed_complex_dtype():
|
||||
x = np.arange(8).astype(float)
|
||||
x = x + 1j*x[::-1]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_dwt_idwt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones(4, dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet)
|
||||
assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)
|
||||
|
||||
|
||||
def test_dwt_idwt_basic_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
|
||||
7.77817459])
|
||||
cA_expect = cA_expect + 0.5j*cA_expect
|
||||
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487])
|
||||
cD_expect = cD_expect + 0.5j*cD_expect
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_idwt_partial_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
|
||||
cA, cD = pywt.dwt(x, 'haar')
|
||||
cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
|
||||
1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
|
||||
cA_rec = pywt.idwt(cA, None, 'haar')
|
||||
assert_allclose(cA_rec, cA_rec_expect)
|
||||
|
||||
cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
|
||||
-3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
|
||||
cD_rec = pywt.idwt(None, cD, 'haar')
|
||||
assert_allclose(cD_rec, cD_rec_expect)
|
||||
|
||||
assert_allclose(cA_rec + cD_rec, x)
|
||||
|
||||
|
||||
def test_dwt_wavelet_kwd():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
|
||||
cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
|
||||
7.81994027]
|
||||
cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
|
||||
-0.09722957, -0.07045258]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
|
||||
def test_dwt_coeff_len():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
|
||||
expected_result = [6, ] * len(pywt.Modes.modes)
|
||||
expected_result[pywt.Modes.modes.index('periodization')] = 4
|
||||
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
|
||||
|
||||
def test_idwt_none_input():
|
||||
# None input equals arrays of zeros of the right length
|
||||
res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
# Only one argument at a time can be None
|
||||
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')
|
||||
|
||||
|
||||
def test_idwt_invalid_input():
|
||||
# Too short, min length is 4 for 'db4':
|
||||
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
|
||||
|
||||
|
||||
def test_dwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
cA0, cD0 = pywt.dwt(x[0], 'db2')
|
||||
cA1, cD1 = pywt.dwt(x[1], 'db2')
|
||||
|
||||
assert_allclose(cA[0], cA0)
|
||||
assert_allclose(cA[1], cA1)
|
||||
|
||||
assert_allclose(cD[0], cD0)
|
||||
assert_allclose(cD[1], cD1)
|
||||
|
||||
|
||||
def test_idwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
x = np.asarray(x)
|
||||
x = x + 1j*x # test with complex data
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
|
||||
x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)
|
||||
|
||||
assert_allclose(x[0], x0)
|
||||
assert_allclose(x[1], x1)
|
||||
|
||||
def test_dwt_invalid_input():
|
||||
x = np.arange(1)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'reflect')
|
||||
assert_raises(ValueError, pywt.dwt, x, 'haar', 'antireflect')
|
||||
|
||||
|
||||
def test_dwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
assert_allclose(cA_, cA)
|
||||
assert_allclose(cD_, cD)
|
||||
|
||||
def test_dwt_axis_invalid_input():
|
||||
x = np.ones((3,1))
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'reflect')
|
||||
|
||||
def test_idwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
x_ = pywt.idwt(cA, cD, 'db2', axis=-1)
|
||||
x = pywt.idwt(cA, cD, 'db2', axis=1)
|
||||
|
||||
assert_allclose(x_, x)
|
||||
|
||||
|
||||
def test_dwt_idwt_axis_excess():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
# can't transform over axes that aren't there
|
||||
assert_raises(ValueError,
|
||||
pywt.dwt, x, 'db2', 'symmetric', axis=2)
|
||||
|
||||
assert_raises(ValueError,
|
||||
pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((32, ))
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, pywt.dwt, data, cwave)
|
||||
|
||||
cA, cD = pywt.dwt(data, 'db1')
|
||||
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
|
||||
|
||||
|
||||
def test_dwt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.dwt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
|
||||
|
||||
|
||||
def test_pad_1d():
|
||||
x = [1, 2, 3]
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
|
||||
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
|
||||
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
|
||||
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
|
||||
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
|
||||
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
|
||||
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
|
||||
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
# equivalence of various pad_width formats
|
||||
assert_array_equal(pywt.pad(x, 4, 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
|
||||
def test_pad_errors():
|
||||
# negative pad width
|
||||
x = [1, 2, 3]
|
||||
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
|
||||
|
||||
# wrong length pad width
|
||||
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
|
||||
|
||||
# invalid mode name
|
||||
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
|
||||
|
||||
|
||||
def test_pad_nd():
|
||||
for ndim in [2, 3]:
|
||||
x = np.arange(4**ndim).reshape((4, ) * ndim)
|
||||
if ndim == 2:
|
||||
pad_widths = [(2, 1), (2, 3)]
|
||||
else:
|
||||
pad_widths = [(2, 1), ] * ndim
|
||||
for mode in pywt.Modes.modes:
|
||||
xp = pywt.pad(x, pad_widths, mode)
|
||||
|
||||
# expected result is the same as applying along axes separably
|
||||
xp_expected = x.copy()
|
||||
for ax in range(ndim):
|
||||
xp_expected = np.apply_along_axis(pywt.pad,
|
||||
ax,
|
||||
xp_expected,
|
||||
pad_widths=[pad_widths[ax]],
|
||||
mode=mode)
|
||||
assert_array_equal(xp, xp_expected)
|
||||
45
.CondaPkg/env/Lib/site-packages/pywt/tests/test_functions.py
vendored
Normal file
45
.CondaPkg/env/Lib/site-packages/pywt/tests/test_functions.py
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from numpy.testing import assert_almost_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_centrfreq():
|
||||
# db1 is Haar function, frequency=1
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
# db2, frequency=2/3
|
||||
w = pywt.Wavelet('db2')
|
||||
expected = 2/3.
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected)
|
||||
|
||||
|
||||
def test_scal2frq_scale():
|
||||
scale = 2
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1. / scale
|
||||
result = pywt.scale2frequency(w, scale, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
|
||||
def test_frq2scal_freq():
|
||||
freq = 2
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1. / freq
|
||||
result = pywt.frequency2scale(w, freq, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
|
||||
|
||||
def test_intwave_orthogonal():
|
||||
w = pywt.Wavelet('db1')
|
||||
int_psi, x = pywt.integrate_wavelet(w, precision=12)
|
||||
ix = x < 0.5
|
||||
# For x < 0.5, the integral is equal to x
|
||||
assert_allclose(int_psi[ix], x[ix])
|
||||
# For x > 0.5, the integral is equal to (1 - x)
|
||||
# Ignore last point here, there x > 1 and something goes wrong
|
||||
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)
|
||||
160
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
vendored
Normal file
160
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Test used to verify PyWavelets Discrete Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
|
||||
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
# TODO: Now have implemented asymmetric modes too.
|
||||
# Would be nice to update the Matlab data to test these as well.
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw'),
|
||||
]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
@uses_pymatbridge
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _load_matlab_result(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, mmode, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
pa, pd = pywt.dwt(data, w, pmode)
|
||||
|
||||
# calculate error measures
|
||||
rms_a = np.sqrt(np.mean((pa - ma) ** 2))
|
||||
rms_d = np.sqrt(np.mean((pd - md) ** 2))
|
||||
|
||||
msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a))
|
||||
assert_(rms_a < epsilon, msg=msg)
|
||||
|
||||
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
|
||||
assert_(rms_d < epsilon, msg=msg)
|
||||
174
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility_cwt.py
vendored
Normal file
174
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility_cwt.py
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Test used to verify PyWavelets Continuous Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
|
||||
matlab_result_dict_cwt)
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
def _get_scales(w):
|
||||
""" Return the scales to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
|
||||
else:
|
||||
scales = (1, np.arange(1, 3))
|
||||
return scales
|
||||
|
||||
|
||||
@uses_pymatbridge # skip this case if precomputed results are used instead
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge_cwt():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 1e-15
|
||||
epsilon_psi = 1e-15
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for scales in _get_scales(w):
|
||||
coefs = _compute_matlab_result(data, wavelet, scales, mlab)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed # skip this case if pymatbridge + Matlab are being used
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed_cwt():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
# has to be improved
|
||||
epsilon = 2e-15
|
||||
epsilon32 = 1e-5
|
||||
epsilon_psi = 1e-15
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
|
||||
psi = _load_matlab_result_psi(wavelet)
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
data32 = data.astype(np.float32)
|
||||
scales_count = 0
|
||||
for scales in _get_scales(w):
|
||||
scales_count += 1
|
||||
coefs = _load_matlab_result(data, wavelet, scales_count)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
_check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, scales, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, scales):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
|
||||
if (coefs_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
|
||||
coefs = matlab_result_dict_cwt[coefs_key]
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result_psi(wavelet):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
if (psi_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab psi result not found for wavelet: "
|
||||
"{0}}".format(wavelet))
|
||||
psi = matlab_result_dict_cwt[psi_key]
|
||||
return psi
|
||||
|
||||
|
||||
def _check_accuracy(data, w, scales, coefs, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
coefs_pywt, freq = pywt.cwt(data, scales, w)
|
||||
|
||||
# coefs from Matlab are from R2012a which is missing the complex conjugate
|
||||
# as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of
|
||||
# the precomputed Matlab result to account for this.
|
||||
coefs = np.conj(coefs)
|
||||
|
||||
# calculate error measures
|
||||
err = coefs_pywt - coefs
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Scale: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (scales, wavelet, len(data), rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
||||
|
||||
|
||||
def _check_accuracy_psi(w, psi, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
psi_pywt, x = w.wavefun(length=1024)
|
||||
|
||||
# calculate error measures
|
||||
err = psi_pywt.flatten() - psi.flatten()
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Wavelet: %s, '
|
||||
'rms=%.3g' % (wavelet, rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
||||
109
.CondaPkg/env/Lib/site-packages/pywt/tests/test_modes.py
vendored
Normal file
109
.CondaPkg/env/Lib/site-packages/pywt/tests/test_modes.py
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_raises, assert_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_available_modes():
|
||||
modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth',
|
||||
'periodization', 'reflect', 'antisymmetric', 'antireflect']
|
||||
assert_equal(pywt.Modes.modes, modes)
|
||||
assert_equal(pywt.Modes.from_object('constant'), 2)
|
||||
|
||||
|
||||
def test_invalid_modes():
|
||||
x = np.arange(4)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', -1)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 9)
|
||||
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
|
||||
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 'unknown')
|
||||
assert_raises(ValueError, pywt.Modes.from_object, -1)
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 9)
|
||||
assert_raises(TypeError, pywt.Modes.from_object, None)
|
||||
|
||||
|
||||
def test_dwt_idwt_allmodes():
|
||||
# Test that :func:`dwt` and :func:`idwt` can be performed using every mode
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
dwt_results = {
|
||||
'zero': ([-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
|
||||
[-0.12940952, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.8625013]),
|
||||
'constant': ([1.28480404, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.51935555],
|
||||
[-0.48296291, -2.15599552, -5.95034847, -1.21545369,
|
||||
0.25881905]),
|
||||
'symmetric': ([1.76776695, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.77817459],
|
||||
[-0.61237244, -2.15599552, -5.95034847, -1.21545369,
|
||||
1.22474487]),
|
||||
'reflect': ([2.12132034, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.81224877],
|
||||
[-0.70710678, -2.15599552, -5.95034847, -1.21545369,
|
||||
-2.38013939]),
|
||||
'periodic': ([6.9162743, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.9162743],
|
||||
[-1.99191082, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.99191082]),
|
||||
'smooth': ([-0.51763809, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.45000519],
|
||||
[0, -2.15599552, -5.95034847, -1.21545369, 0]),
|
||||
'periodization': ([4.053172, 3.05257099, 2.85381112, 8.42522221],
|
||||
[0.18946869, 4.18258152, 4.33737503, 2.60428326]),
|
||||
'antisymmetric': ([-1.83711731, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.12372436],
|
||||
[0.353553391, -2.15599552, -5.95034847, -1.21545369,
|
||||
-4.94974747]),
|
||||
'antireflect': ([0.44828774, 1.73309178, 3.40612438, 6.32928585,
|
||||
8.22646233],
|
||||
[-0.25881905, -2.15599552, -5.95034847, -1.21545369,
|
||||
2.89777748])
|
||||
}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
cA, cD = pywt.dwt(x, 'db2', mode)
|
||||
assert_allclose(cA, dwt_results[mode][0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, dwt_results[mode][1], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2', mode), x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_short_input_allmodes():
|
||||
# some test cases where the input is shorter than the DWT filter
|
||||
x = [1, 3, 2]
|
||||
wavelet = 'db2'
|
||||
# manually pad each end by the filter size (4 for 'db2' used here)
|
||||
padded_x = {'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
|
||||
'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
|
||||
'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
|
||||
'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
|
||||
'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
|
||||
'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
|
||||
'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
|
||||
'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
|
||||
}
|
||||
for mode, xpad in padded_x.items():
|
||||
# DWT of the manually padded array. will discard edges later so
|
||||
# symmetric mode used here doesn't matter.
|
||||
cApad, cDpad = pywt.dwt(xpad, wavelet, mode='symmetric')
|
||||
|
||||
# central region of the padded output (unaffected by mode )
|
||||
expected_result = (cApad[2:-2], cDpad[2:-2])
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet, mode)
|
||||
assert_allclose(cA, expected_result[0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, expected_result[1], rtol=1e-7, atol=1e-8)
|
||||
|
||||
|
||||
def test_default_mode():
|
||||
# The default mode should be 'symmetric'
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA2, cD2 = pywt.dwt(x, 'db2', mode='symmetric')
|
||||
assert_allclose(cA, cA2)
|
||||
assert_allclose(cD, cD2)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)
|
||||
255
.CondaPkg/env/Lib/site-packages/pywt/tests/test_mra.py
vendored
Normal file
255
.CondaPkg/env/Lib/site-packages/pywt/tests/test_mra.py
vendored
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
import pywt
|
||||
from pywt import data
|
||||
|
||||
# tolerances used in accuracy comparisons
|
||||
tol_single = 1e-6
|
||||
tol_double = 1e-13
|
||||
atol = 1e-7
|
||||
|
||||
|
||||
####
|
||||
# 1d mra tests
|
||||
####
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
|
||||
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
||||
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
||||
@pytest.mark.parametrize(
|
||||
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
||||
)
|
||||
def test_mra_roundtrip(wavelet, transform, mode, dtype):
|
||||
x = data.ecg()[:64].astype(dtype)
|
||||
if x.dtype.kind == 'c':
|
||||
# fill some data for the imaginary channel
|
||||
x.imag = x[::-1].real
|
||||
|
||||
if transform == 'swt':
|
||||
# swt mode only supports periodization
|
||||
if mode != 'periodization':
|
||||
with pytest.raises(ValueError):
|
||||
pywt.mra(x, wavelet, transform=transform, mode=mode)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra(x, wavelet, transform=transform, mode=mode)
|
||||
assert isinstance(coeffs, list)
|
||||
assert isinstance(coeffs[0], np.ndarray)
|
||||
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
||||
|
||||
y = pywt.imra(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
||||
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
||||
def test_mra_warns_on_non_orthogonal(wavelet, transform):
|
||||
dtype = np.float64
|
||||
x = data.ecg()[:64].astype(dtype)
|
||||
|
||||
assert not pywt.Wavelet(wavelet).orthogonal
|
||||
|
||||
if transform == 'swt':
|
||||
# bi-orthogonal wavelets raise a warning for SWT case
|
||||
msg = 'norm=True, but the wavelet is not orthogonal'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
coeffs = pywt.mra(x, wavelet, transform=transform)
|
||||
else:
|
||||
coeffs = pywt.mra(x, wavelet, transform=transform)
|
||||
|
||||
y = pywt.imra(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('axis', [0, -1, 1, 2, -3])
|
||||
@pytest.mark.parametrize('ndim', [1, 2, 3])
|
||||
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
||||
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
|
||||
def test_mra_axis(transform, ndim, axis, dtype):
|
||||
# Test transforms over a specific axis of 1D, 2D or 3D data
|
||||
if ndim == 1:
|
||||
x = data.ecg()[:64]
|
||||
elif ndim == 2:
|
||||
x = data.camera()[:64, :32]
|
||||
elif ndim == 3:
|
||||
x = data.camera()[:48, :8]
|
||||
x = np.stack((x,) * 8, axis=-1)
|
||||
x = x.astype(dtype, copy=False)
|
||||
|
||||
# out of range axis
|
||||
if axis < -x.ndim or axis >= x.ndim:
|
||||
with pytest.raises(np.AxisError):
|
||||
pywt.mra(x, 'db1', transform=transform, axis=axis)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra(x, 'db1', transform=transform, axis=axis)
|
||||
y = pywt.imra(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
####
|
||||
# 2d mra tests
|
||||
####
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
|
||||
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
||||
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
||||
@pytest.mark.parametrize(
|
||||
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
||||
)
|
||||
def test_mra2_roundtrip(wavelet, transform, mode, dtype):
|
||||
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
||||
if x.dtype.kind == 'c':
|
||||
# fill some data for the imaginary channel
|
||||
x.imag = x[::-1, :].real
|
||||
|
||||
if transform == 'swt2':
|
||||
# swt mode only supports periodization
|
||||
if mode != 'periodization':
|
||||
with pytest.raises(ValueError):
|
||||
pywt.mra2(x, wavelet, transform=transform, mode=mode)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra2(x, wavelet, transform=transform, mode=mode)
|
||||
assert isinstance(coeffs, list)
|
||||
assert isinstance(coeffs[0], np.ndarray)
|
||||
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
||||
|
||||
y = pywt.imra2(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
||||
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
||||
def test_mra2_warns_on_non_orthogonal(wavelet, transform):
|
||||
dtype = np.float64
|
||||
x = data.camera()[:32, :8].astype(dtype, copy=False)
|
||||
|
||||
assert not pywt.Wavelet(wavelet).orthogonal
|
||||
|
||||
if transform == 'swt2':
|
||||
# bi-orthogonal wavelets raise a warning for SWT case
|
||||
msg = 'norm=True, but the wavelets used are not orthogonal'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
coeffs = pywt.mra2(x, wavelet, transform=transform)
|
||||
else:
|
||||
coeffs = pywt.mra2(x, wavelet, transform=transform)
|
||||
|
||||
y = pywt.imra2(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
||||
@pytest.mark.parametrize('ndim', [2, 3])
|
||||
@pytest.mark.parametrize('axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4)])
|
||||
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
|
||||
def test_mra2_axes(transform, axes, ndim, dtype):
|
||||
# Test transforms over various axes of 2D or 3D data.
|
||||
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
||||
if ndim == 3:
|
||||
x = np.stack((x,) * 8, axis=-1)
|
||||
|
||||
# out of range axis
|
||||
if any([axis < -x.ndim or axis >= x.ndim for axis in axes]):
|
||||
with pytest.raises(np.AxisError):
|
||||
pywt.mra2(x, 'db1', transform=transform, axes=axes)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra2(x, 'db1', transform=transform, axes=axes)
|
||||
y = pywt.imra2(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
####
|
||||
# nd mra tests
|
||||
####
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['sym2', ])
|
||||
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
||||
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
||||
@pytest.mark.parametrize(
|
||||
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
||||
)
|
||||
@pytest.mark.parametrize('ndim', [1, 2, 3])
|
||||
def test_mran_roundtrip(wavelet, transform, mode, dtype, ndim):
|
||||
if ndim == 1:
|
||||
x = data.ecg()[:48].astype(dtype, copy=False)
|
||||
elif ndim == 2:
|
||||
x = data.camera()[:16, :8].astype(dtype, copy=False)
|
||||
elif ndim == 3:
|
||||
x = data.camera()[:16, :8].astype(dtype, copy=False)
|
||||
x = np.stack((x,) * 8, axis=-1)
|
||||
|
||||
if x.dtype.kind == 'c':
|
||||
# fill some data for the imaginary channel
|
||||
x.imag = x[::-1, ...].real
|
||||
|
||||
if transform == 'swtn':
|
||||
# swt mode only supports periodization
|
||||
if mode != 'periodization':
|
||||
with pytest.raises(ValueError):
|
||||
pywt.mran(x, wavelet, transform=transform, mode=mode)
|
||||
return
|
||||
|
||||
coeffs = pywt.mran(x, wavelet, transform=transform, mode=mode)
|
||||
assert isinstance(coeffs, list)
|
||||
assert isinstance(coeffs[0], np.ndarray)
|
||||
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
||||
|
||||
y = pywt.imran(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
||||
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
||||
def test_mran_warns_on_non_orthogonal(wavelet, transform):
|
||||
dtype = np.float64
|
||||
x = data.camera()[:32, :8].astype(dtype, copy=False)
|
||||
|
||||
assert not pywt.Wavelet(wavelet).orthogonal
|
||||
|
||||
if transform == 'swtn':
|
||||
# bi-orthogonal wavelets raise a warning for SWT case
|
||||
msg = 'norm=True, but the wavelets used are not orthogonal'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
coeffs = pywt.mran(x, wavelet, transform=transform)
|
||||
else:
|
||||
coeffs = pywt.mran(x, wavelet, transform=transform)
|
||||
|
||||
y = pywt.imran(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4), (-3, -2, -1),
|
||||
(0, 2, 1), (0, 5, 1), (0,), (1,), (2,), (-2,), (-3,), (-4,)])
|
||||
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
||||
def test_mran_axes(axes, transform):
|
||||
# Test with transforms over 1, 2 or 3 axes of 3d data.
|
||||
# Cases with out of range axes are also tested
|
||||
dtype = np.float64
|
||||
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
||||
x3d = np.stack((x,) * 8, axis=-1)
|
||||
|
||||
# out of range axis
|
||||
if any([axis < -x.ndim or axis >= x.ndim for axis in axes]):
|
||||
with pytest.raises(np.AxisError):
|
||||
pywt.mran(x, 'db1', transform='dwtn', axes=axes)
|
||||
return
|
||||
|
||||
coeffs = pywt.mran(x3d, 'db1', transform='dwtn', axes=axes)
|
||||
y = pywt.imran(coeffs)
|
||||
rtol = tol_single if x3d.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x3d, y, rtol=rtol, atol=rtol)
|
||||
443
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multidim.py
vendored
Normal file
443
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multidim.py
vendored
Normal file
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from itertools import combinations
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
|
||||
|
||||
import pywt
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwtn_input():
|
||||
# Array-like must be accepted
|
||||
pywt.dwtn([1, 2, 3, 4], 'haar')
|
||||
# Others must not
|
||||
data = dict()
|
||||
assert_raises(TypeError, pywt.dwtn, data, 'haar')
|
||||
# Must be at least 1D
|
||||
assert_raises(ValueError, pywt.dwtn, 2, 'haar')
|
||||
|
||||
|
||||
def test_3D_reconstruct():
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for mode in pywt.Modes.modes:
|
||||
d = pywt.dwtn(data, wavelet, mode=mode)
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_dwdtn_idwtn_allwavelets():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16, 16)
|
||||
# test 2D case only for all wavelet types
|
||||
wavelist = pywt.wavelist()
|
||||
if 'dmey' in wavelist:
|
||||
wavelist.remove('dmey')
|
||||
for wavelet in wavelist:
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
if isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet):
|
||||
for mode in pywt.Modes.modes:
|
||||
coeffs = pywt.dwtn(r, wavelet, mode=mode)
|
||||
assert_allclose(pywt.idwtn(coeffs, wavelet, mode=mode),
|
||||
r, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def test_stride():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
strided = np.ones((3, 12), dtype=data.dtype)
|
||||
strided[::-1, ::2] = data
|
||||
strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(strided_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_byte_offset():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
padded = np.ones((3, 6), dtype=np.dtype({'data': (data.dtype, 0),
|
||||
'pad': ('byte', data.dtype.itemsize)},
|
||||
align=True))
|
||||
padded[:] = data
|
||||
padded_dwtn = pywt.dwtn(padded['data'], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(padded_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_3D_reconstruct_complex():
|
||||
# All dimensions even length so `take` does not need to be specified
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
data = data + 1j
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
d = pywt.dwtn(data, wavelet)
|
||||
# idwtn creates even-length shapes (2x dwtn size)
|
||||
original_shape = tuple([slice(None, s) for s in data.shape])
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_idwtn_idwt2():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_idwt2_complex():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
data = data + 1j
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_missing():
|
||||
# Test to confirm missing data behave as zeroes
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
coefs = pywt.dwtn(data, wavelet)
|
||||
|
||||
# No point removing zero, or all
|
||||
for num_missing in range(1, len(coefs)):
|
||||
for missing in combinations(coefs.keys(), num_missing):
|
||||
missing_coefs = coefs.copy()
|
||||
for key in missing:
|
||||
del missing_coefs[key]
|
||||
LL = missing_coefs.get('aa', None)
|
||||
HL = missing_coefs.get('da', None)
|
||||
LH = missing_coefs.get('ad', None)
|
||||
HH = missing_coefs.get('dd', None)
|
||||
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet),
|
||||
pywt.idwtn(missing_coefs, 'haar'), atol=1e-15)
|
||||
|
||||
|
||||
def test_idwtn_all_coeffs_None():
|
||||
coefs = dict(aa=None, da=None, ad=None, dd=None)
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
|
||||
|
||||
|
||||
def test_error_on_invalid_keys():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# unexpected key
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
# mismatched key lengths
|
||||
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_error_mismatched_size():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# Pass/fail depends on first element being shorter than remaining ones so
|
||||
# set 3/4 to an incorrect size to maximize chances. Order of dict items
|
||||
# is random so may not trigger on every test run. Dict is constructed
|
||||
# inside idwtn function so no use using an OrderedDict here.
|
||||
LL = LL[:, :-1]
|
||||
LH = LH[:, :-1]
|
||||
HH = HH[:, :-1]
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_dwt2_idwt2_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
|
||||
"dwt2: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
|
||||
|
||||
|
||||
def test_dwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1,))
|
||||
expected_a = list(map(lambda x: pywt.dwt(x, 'haar')[0], data))
|
||||
assert_equal(coefs['a'], expected_a)
|
||||
expected_d = list(map(lambda x: pywt.dwt(x, 'haar')[1], data))
|
||||
assert_equal(coefs['d'], expected_d)
|
||||
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
expected_aa = list(map(lambda x: pywt.dwt(x, 'haar')[0], expected_a))
|
||||
assert_equal(coefs['aa'], expected_aa)
|
||||
expected_ad = list(map(lambda x: pywt.dwt(x, 'haar')[1], expected_a))
|
||||
assert_equal(coefs['ad'], expected_ad)
|
||||
|
||||
|
||||
def test_idwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwt2_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
cD = np.zeros_like(cD)
|
||||
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
cD = None
|
||||
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwtn_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
coefs['dd'] = np.zeros_like(coefs['dd'])
|
||||
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
coefs['dd'] = None
|
||||
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwt2_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
# too many axes
|
||||
assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_idwt2_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4)))
|
||||
# test all combinations of 2 out of 3 axes transformed
|
||||
for axes in combinations((0, 1, 2), 2):
|
||||
coefs = pywt.dwt2(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4, 4)))
|
||||
# test all combinations of 3 out of 4 axes transformed
|
||||
for axes in combinations((0, 1, 2, 3), 3):
|
||||
coefs = pywt.dwtn(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_negative_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs1 = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
coefs2 = pywt.dwtn(data, 'haar', axes=(-1, -1))
|
||||
assert_equal(coefs1, coefs2)
|
||||
|
||||
rec1 = pywt.idwtn(coefs1, 'haar', axes=(1, 1))
|
||||
rec2 = pywt.idwtn(coefs1, 'haar', axes=(-1, -1))
|
||||
assert_equal(rec1, rec2)
|
||||
|
||||
|
||||
def test_dwtn_idwtn_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
coeffs = pywt.dwtn(x, wavelet)
|
||||
for k, v in coeffs.items():
|
||||
assert_(v.dtype == dt_out, "dwtn: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)
|
||||
|
||||
|
||||
def test_idwtn_mixed_complex_dtype():
|
||||
rstate = np.random.RandomState(0)
|
||||
x = rstate.randn(8, 8, 8)
|
||||
x = x + 1j*x
|
||||
coeffs = pywt.dwtn(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
|
||||
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_idwt2_size_mismatch_error():
|
||||
LL = np.zeros((6, 6))
|
||||
LH = HL = HH = np.zeros((5, 5))
|
||||
|
||||
assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')
|
||||
|
||||
|
||||
def test_dwt2_dimension_error():
|
||||
data = np.ones(16)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
# wrong number of input dimensions
|
||||
assert_raises(ValueError, pywt.dwt2, data, wavelet)
|
||||
|
||||
# too many axes
|
||||
data2 = np.ones((8, 8))
|
||||
assert_raises(ValueError, pywt.dwt2, data2, wavelet, axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_per_axis_wavelets_and_modes():
|
||||
# tests separate wavelet and edge mode for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
# mode can be a string or a Modes enum
|
||||
modes = ('symmetric', 'periodization',
|
||||
pywt._extensions._pywt.Modes.reflect)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets[:1], modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes[:1])
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets or modes doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])
|
||||
|
||||
# dwt2/idwt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
|
||||
assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
|
||||
atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
|
||||
[pywt.idwt2, pywt.idwtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_fun, data, wavelet=cwave)
|
||||
|
||||
c = dec_fun(data, 'db1')
|
||||
assert_raises(ValueError, rec_fun, c, wavelet=cwave)
|
||||
1033
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multilevel.py
vendored
Normal file
1033
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multilevel.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
61
.CondaPkg/env/Lib/site-packages/pywt/tests/test_perfect_reconstruction.py
vendored
Normal file
61
.CondaPkg/env/Lib/site-packages/pywt/tests/test_perfect_reconstruction.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Verify DWT perfect reconstruction.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_perfect_reconstruction():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
dtypes = (np.float32, np.float64)
|
||||
|
||||
for wavelet in wavelets:
|
||||
for pmode, mmode in modes:
|
||||
for dt in dtypes:
|
||||
check_reconstruction(pmode, mmode, wavelet, dt)
|
||||
|
||||
|
||||
def check_reconstruction(pmode, mmode, wavelet, dtype):
|
||||
data_size = list(range(2, 40)) + [100, 200, 500, 1000, 2000, 10000,
|
||||
50000, 100000]
|
||||
np.random.seed(12345)
|
||||
# TODO: smoke testing - more failures for different seeds
|
||||
|
||||
if dtype == np.float32:
|
||||
# was 3e-7 has to be lowered as db21, db29, db33, db35, coif14, coif16 were failing
|
||||
epsilon = 6e-7
|
||||
else:
|
||||
epsilon = 5e-11
|
||||
|
||||
for N in data_size:
|
||||
data = np.asarray(np.random.random(N), dtype)
|
||||
|
||||
# compute dwt coefficients
|
||||
pa, pd = pywt.dwt(data, wavelet, pmode)
|
||||
|
||||
# compute reconstruction
|
||||
rec = pywt.idwt(pa, pd, wavelet, pmode)
|
||||
|
||||
if len(data) % 2:
|
||||
rec = rec[:len(data)]
|
||||
|
||||
rms_rec = np.sqrt(np.mean((data-rec)**2))
|
||||
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
|
||||
assert_(rms_rec < epsilon, msg=msg)
|
||||
632
.CondaPkg/env/Lib/site-packages/pywt/tests/test_swt.py
vendored
Normal file
632
.CondaPkg/env/Lib/site-packages/pywt/tests/test_swt.py
vendored
Normal file
@@ -0,0 +1,632 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from itertools import combinations, permutations
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import (assert_allclose, assert_, assert_equal,
|
||||
assert_raises, assert_array_equal, assert_warns)
|
||||
|
||||
import pywt
|
||||
from pywt._extensions._swt import swt_axis
|
||||
|
||||
# Check that float32 and complex64 are preserved. Other real types get
|
||||
# converted to float64.
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# tolerances used in accuracy comparisons
|
||||
tol_single = 1e-6
|
||||
tol_double = 1e-13
|
||||
|
||||
####
|
||||
# 1d multilevel swt tests
|
||||
####
|
||||
|
||||
|
||||
def test_swt_decomposition():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
db1 = pywt.Wavelet('db1')
|
||||
atol = tol_double
|
||||
(cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
|
||||
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
|
||||
2.82842712, 7.07106781, 7.07106781, 6.36396103]
|
||||
assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
|
||||
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
|
||||
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
|
||||
assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
|
||||
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
expected_cA3 = [9.89949494, ] * 8
|
||||
assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
|
||||
expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
|
||||
0.00000000, 3.53553391, 4.24264069, 2.12132034]
|
||||
assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)
|
||||
|
||||
# level=1, start_level=1 decomposition should match level=2
|
||||
res = pywt.swt(cA1, db1, level=1, start_level=1)
|
||||
cA2, cD2 = res[0]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
|
||||
coeffs = pywt.swt(x, db1)
|
||||
assert_(len(coeffs) == 3)
|
||||
assert_(pywt.swt_max_level(len(x)), 3)
|
||||
|
||||
|
||||
def test_swt_max_level():
|
||||
# odd sized signal will warn about no levels of decomposition possible
|
||||
assert_warns(UserWarning, pywt.swt_max_level, 11)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', UserWarning)
|
||||
assert_equal(pywt.swt_max_level(11), 0)
|
||||
|
||||
# no warnings when >= 1 level of decomposition possible
|
||||
assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1
|
||||
assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2
|
||||
assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4
|
||||
assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4
|
||||
|
||||
|
||||
def test_swt_axis():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
|
||||
db1 = pywt.Wavelet('db1')
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
|
||||
|
||||
# test cases use 2D arrays based on tiling x along an axis and then
|
||||
# calling swt along the other axis.
|
||||
for order in ['C', 'F']:
|
||||
# test SWT of 2D data along default axis (-1)
|
||||
x_2d = np.asarray(x).reshape((1, -1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=0)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each row should match the 1D result
|
||||
for row in cA1_2d:
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d:
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d:
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d:
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# test SWT of 2D data along other axis (0)
|
||||
x_2d = np.asarray(x).reshape((-1, 1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=1)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2,
|
||||
axis=0)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each column should match the 1D result
|
||||
for row in cA1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# axis too large
|
||||
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
|
||||
|
||||
|
||||
def test_swt_iswt_integration():
|
||||
# This function performs a round-trip swt/iswt transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt or iswt as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet seems to be a bit special - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length)
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
# swt
|
||||
x = np.ones(8, dtype=dt_in)
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2)
|
||||
assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out,
|
||||
"swt: " + errmsg)
|
||||
|
||||
# swt2
|
||||
x = np.ones((8, 8), dtype=dt_in)
|
||||
cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0]
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out,
|
||||
"swt2: " + errmsg)
|
||||
|
||||
|
||||
def test_swt_roundtrip_dtypes():
|
||||
# verify perfect reconstruction for all dtypes
|
||||
rstate = np.random.RandomState(5)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
# swt, iswt
|
||||
x = rstate.standard_normal((8, )).astype(dt_in)
|
||||
c = pywt.swt(x, wavelet, level=2)
|
||||
xr = pywt.iswt(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
# swt2, iswt2
|
||||
x = rstate.standard_normal((8, 8)).astype(dt_in)
|
||||
c = pywt.swt2(x, wavelet, level=2)
|
||||
xr = pywt.iswt2(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_default_level_by_axis():
|
||||
# make sure default number of levels matches the max level along the axis
|
||||
wav = 'db2'
|
||||
x = np.ones((2**3, 2**4, 2**5))
|
||||
for axis in (0, 1, 2):
|
||||
sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis)
|
||||
assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis]))
|
||||
|
||||
|
||||
def test_swt2_ndim_error():
|
||||
x = np.ones(8)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swt2_iswt2_integration(wavelets=None):
|
||||
# This function performs a round-trip swt2/iswt2 transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt2 or iswt2 as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt2(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
def test_swt2_iswt2_quick():
|
||||
test_swt2_iswt2_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_swt2_iswt2_non_square(wavelets=None):
|
||||
for nrows in [8, 16, 48]:
|
||||
X = np.arange(nrows*32).reshape(nrows, 32)
|
||||
current_wavelet = 'db1'
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
coeffs = pywt.swt2(X, current_wavelet, level=2)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet)
|
||||
assert_allclose(Y, X, rtol=tol_single, atol=tol_single)
|
||||
|
||||
|
||||
def test_swt2_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
(cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0]
|
||||
# opposite order
|
||||
(cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1,
|
||||
axes=(1, 0))[0]
|
||||
assert_allclose(cA1, cA2, atol=atol)
|
||||
assert_allclose(cH1, cV2, atol=atol)
|
||||
assert_allclose(cV1, cH2, atol=atol)
|
||||
assert_allclose(cD1, cD2, atol=atol)
|
||||
|
||||
# reverify iswt2 restores the original data
|
||||
r1 = pywt.iswt2([cA1, (cH1, cV1, cD1)], current_wavelet)
|
||||
assert_allclose(X, r1, atol=atol)
|
||||
r2 = pywt.iswt2([cA2, (cH2, cV2, cD2)], current_wavelet, axes=(1, 0))
|
||||
assert_allclose(X, r2, atol=atol)
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
|
||||
axes=(0, 0))
|
||||
# too few axes
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
|
||||
|
||||
|
||||
def test_swtn_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
|
||||
# opposite order
|
||||
coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
|
||||
assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
|
||||
assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
|
||||
assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
|
||||
assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)
|
||||
|
||||
# 0-level transform
|
||||
empty = pywt.swtn(X, current_wavelet, level=0)
|
||||
assert_equal(empty, [])
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
|
||||
|
||||
# data.ndim = 0
|
||||
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
|
||||
|
||||
# start_level too large
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
|
||||
level=1, start_level=2)
|
||||
|
||||
# level < 1 in swt_axis call
|
||||
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
|
||||
start_level=0)
|
||||
# odd-sized data not allowed
|
||||
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
|
||||
start_level=0, axis=0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swtn_iswtn_integration(wavelets=None):
|
||||
# This function performs a round-trip swtn/iswtn transform for various
|
||||
# possible combinations of:
|
||||
# 1.) 1 out of 2 axes of a 2D array
|
||||
# 2.) 2 out of 3 axes of a 3D array
|
||||
#
|
||||
# To keep test time down, only wavelets of length <= 8 are run.
|
||||
#
|
||||
# This test does not validate swtn or iswtn individually, but only
|
||||
# confirms that iswtn yields an (almost) perfect reconstruction of swtn.
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for ndim_transform in range(1, 3):
|
||||
ndim = ndim_transform + 1
|
||||
for axes in combinations(range(ndim), ndim_transform):
|
||||
for current_wavelet_str in wavelets:
|
||||
wav = pywt.Wavelet(current_wavelet_str)
|
||||
if wav.dec_len > 8:
|
||||
continue # avoid excessive test duration
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
wav.dec_len,
|
||||
wav.rec_len))))
|
||||
N = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(N**ndim).reshape((N, )*ndim)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not wav.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swtn(X, wav, max_level, axes=axes,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
coeffs_copy = deepcopy(coeffs)
|
||||
Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
# verify the inverse transform didn't modify any coeffs
|
||||
for c, c2 in zip(coeffs, coeffs_copy):
|
||||
for k, v in c.items():
|
||||
assert_array_equal(c2[k], v)
|
||||
|
||||
|
||||
def test_swtn_iswtn_quick():
|
||||
test_swtn_iswtn_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_iswtn_errors():
|
||||
x = np.arange(8**3).reshape(8, 8, 8)
|
||||
max_level = 2
|
||||
axes = (0, 1)
|
||||
w = pywt.Wavelet('db1')
|
||||
coeffs = pywt.swtn(x, w, max_level, axes=axes)
|
||||
|
||||
# more axes than dimensions transformed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
|
||||
# mismatched coefficient size
|
||||
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
|
||||
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
|
||||
|
||||
|
||||
def test_swtn_iswtn_unique_shape_per_axis():
|
||||
# test case for gh-460
|
||||
_shape = (1, 48, 32) # unique shape per axis
|
||||
wav = 'sym2'
|
||||
max_level = 3
|
||||
rstate = np.random.RandomState(0)
|
||||
for shape in permutations(_shape):
|
||||
# transform only along the non-singleton axes
|
||||
axes = [ax for ax, s in enumerate(shape) if s != 1]
|
||||
x = rstate.standard_normal(shape)
|
||||
c = pywt.swtn(x, wav, max_level, axes=axes)
|
||||
r = pywt.iswtn(c, wav, axes=axes)
|
||||
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
|
||||
|
||||
|
||||
def test_per_axis_wavelets():
|
||||
# tests separate wavelet for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
level = 3
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
coefs = pywt.swtn(data, wavelets, level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)
|
||||
|
||||
# 1-tuple also okay
|
||||
coefs = pywt.swtn(data, wavelets[:1], level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
|
||||
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
# swt2/iswt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.swt2(data2, wavelets[:2], level)
|
||||
assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_func, data, wavelet=cwave,
|
||||
level=3)
|
||||
|
||||
c = dec_func(data, 'db1', level=3)
|
||||
assert_raises(ValueError, rec_func, c, wavelet=cwave)
|
||||
|
||||
|
||||
def test_iswt_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
x_real = np.arange(16).astype(np.float64)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
coeffs[0][1].astype(dtype2)]
|
||||
y = pywt.iswt(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswt2_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt2(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
|
||||
y = pywt.iswt2(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswtn_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swtn(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
a = coeffs[0].pop('a' * x.ndim)
|
||||
a = a.astype(dtype1)
|
||||
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
|
||||
coeffs[0]['a' * x.ndim] = a
|
||||
y = pywt.iswtn(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_swt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.swt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
|
||||
|
||||
|
||||
def test_swt_variance_and_energy_preservation():
|
||||
"""Verify that the 1D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(256)
|
||||
coeffs = pywt.swt(x, wav, trim_approx=True, norm=True)
|
||||
variances = [np.var(c) for c in coeffs]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeffs)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
|
||||
|
||||
|
||||
def test_swt2_variance_and_energy_preservation():
|
||||
"""Verify that the 2D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for v in d:
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swtn_variance_and_energy_preservation():
|
||||
"""Verify that the nD SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for k, v in d.items():
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swt_ravel_and_unravel():
|
||||
# When trim_approx=True, all swt functions can user pywt.ravel_coeffs
|
||||
for ndim, _swt, _iswt, ravel_type in [
|
||||
(1, pywt.swt, pywt.iswt, 'swt'),
|
||||
(2, pywt.swt2, pywt.iswt2, 'swt2'),
|
||||
(3, pywt.swtn, pywt.iswtn, 'swtn')]:
|
||||
x = np.ones((16, ) * ndim)
|
||||
c = _swt(x, 'sym2', level=3, trim_approx=True)
|
||||
arr, slices, shapes = pywt.ravel_coeffs(c)
|
||||
c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type)
|
||||
r = _iswt(c, 'sym2')
|
||||
assert_allclose(x, r)
|
||||
169
.CondaPkg/env/Lib/site-packages/pywt/tests/test_thresholding.py
vendored
Normal file
169
.CondaPkg/env/Lib/site-packages/pywt/tests/test_thresholding.py
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
|
||||
real_dtypes = [np.float32, np.float64]
|
||||
|
||||
|
||||
def _sign(x):
|
||||
# Matlab-like sign function (numpy uses a different convention).
|
||||
return x / np.abs(x)
|
||||
|
||||
|
||||
def _soft(x, thresh):
|
||||
"""soft thresholding supporting complex values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This version is not robust to zeros in x.
|
||||
"""
|
||||
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
|
||||
|
||||
|
||||
def test_threshold():
|
||||
data = np.linspace(1, 4, 7)
|
||||
|
||||
# soft
|
||||
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'soft'),
|
||||
np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'soft'),
|
||||
-np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
|
||||
[[0, 1]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
|
||||
# soft thresholding complex values
|
||||
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
|
||||
[[0j, 1j]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
complex_data = [[1+2j, 2+2j]]*2
|
||||
for thresh in [1, 2]:
|
||||
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
|
||||
_soft(complex_data, thresh), rtol=1e-12)
|
||||
|
||||
# test soft thresholding with non-default substitute argument
|
||||
s = 5
|
||||
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
|
||||
[[s, 0.5]] * 2, rtol=1e-12)
|
||||
|
||||
# soft: no divide by zero warnings when input contains zeros
|
||||
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
|
||||
np.zeros(16), rtol=1e-12)
|
||||
|
||||
# hard
|
||||
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'hard'),
|
||||
np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'hard'),
|
||||
-np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
|
||||
[[0, 2+2j]] * 2, rtol=1e-12)
|
||||
|
||||
# greater
|
||||
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'greater'),
|
||||
np.array(greater_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
# greater doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
|
||||
|
||||
# less
|
||||
assert_allclose(pywt.threshold(data, 2, 'less'),
|
||||
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
|
||||
[[1, 0]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
|
||||
[[1, s]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
|
||||
# less doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
|
||||
|
||||
# invalid
|
||||
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
|
||||
|
||||
|
||||
def test_nonnegative_garotte():
|
||||
thresh = 0.3
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_garotte = pywt.threshold(data, thresh, 'garotte')
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_garotte.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_garotte[lt] == 0))
|
||||
|
||||
# values > than the threshold are intermediate between soft and hard
|
||||
gt = np.where(np.abs(data) > thresh)
|
||||
gt_abs_garotte = np.abs(d_garotte[gt])
|
||||
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
|
||||
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
|
||||
|
||||
|
||||
def test_threshold_firm():
|
||||
thresh = 0.2
|
||||
thresh2 = 3 * thresh
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
if data.real.dtype == np.float32:
|
||||
rtol = atol = 1e-6
|
||||
else:
|
||||
rtol = atol = 1e-14
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_firm = pywt.threshold_firm(data, thresh, thresh2)
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_firm.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_firm[lt] == 0))
|
||||
|
||||
# values > than the threshold are equal to hard-thresholding
|
||||
gt = np.where(np.abs(data) >= thresh2)
|
||||
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
# other values are intermediate between soft and hard thresholding
|
||||
mt = np.where(np.logical_and(np.abs(data) > thresh,
|
||||
np.abs(data) < thresh2))
|
||||
mt_abs_firm = np.abs(d_firm[mt])
|
||||
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
|
||||
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
|
||||
277
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wavelet.py
vendored
Normal file
277
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wavelet.py
vendored
Normal file
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_properties():
|
||||
w = pywt.Wavelet('db3')
|
||||
|
||||
# Name
|
||||
assert_(w.name == 'db3')
|
||||
assert_(w.short_family_name == 'db')
|
||||
assert_(w.family_name, 'Daubechies')
|
||||
|
||||
# String representation
|
||||
fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
|
||||
'Biorthogonal', 'Symmetry')
|
||||
for field in fields:
|
||||
assert_(field in str(w))
|
||||
|
||||
# Filter coefficients
|
||||
dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
|
||||
0.45987750211933, 0.80689150931334, 0.33267055295096]
|
||||
dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
|
||||
-0.13501102001039, 0.08544127388224, 0.03522629188210]
|
||||
rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
|
||||
-0.13501102001039, -0.08544127388224, 0.03522629188210]
|
||||
rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
|
||||
-0.45987750211933, 0.80689150931334, -0.33267055295096]
|
||||
assert_allclose(w.dec_lo, dec_lo)
|
||||
assert_allclose(w.dec_hi, dec_hi)
|
||||
assert_allclose(w.rec_lo, rec_lo)
|
||||
assert_allclose(w.rec_hi, rec_hi)
|
||||
|
||||
assert_(len(w.filter_bank) == 4)
|
||||
|
||||
# Orthogonality
|
||||
assert_(w.orthogonal)
|
||||
assert_(w.biorthogonal)
|
||||
|
||||
# Symmetry
|
||||
assert_(w.symmetry)
|
||||
|
||||
# Vanishing moments
|
||||
assert_(w.vanishing_moments_phi == 0)
|
||||
assert_(w.vanishing_moments_psi == 3)
|
||||
|
||||
|
||||
def test_wavelet_coefficients():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
for wavelet in wavelets:
|
||||
if (pywt.Wavelet(wavelet).orthogonal):
|
||||
check_coefficients_orthogonal(wavelet)
|
||||
elif(pywt.Wavelet(wavelet).biorthogonal):
|
||||
check_coefficients_biorthogonal(wavelet)
|
||||
else:
|
||||
check_coefficients(wavelet)
|
||||
|
||||
|
||||
def check_coefficients_orthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi, psi, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
|
||||
res = np.sum(phi) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Wavelet function is orthogonal to the scaling function at the same scale
|
||||
res = np.sum(phi*psi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# The lowpass and highpass filter coefficients are orthogonal
|
||||
res = np.sum(np.array(w.dec_lo)*np.array(w.dec_hi))
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients_biorthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
res = np.sum(phi_d) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(phi_r) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients(wavelet):
|
||||
epsilon = 5e-11
|
||||
level = 10
|
||||
w = pywt.Wavelet(wavelet)
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
class _CustomHaarFilterBank(object):
|
||||
@property
|
||||
def filter_bank(self):
|
||||
val = np.sqrt(2) / 2
|
||||
return ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
|
||||
|
||||
def test_custom_wavelet():
|
||||
haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=_CustomHaarFilterBank())
|
||||
haar_custom1.orthogonal = True
|
||||
haar_custom1.biorthogonal = True
|
||||
|
||||
val = np.sqrt(2) / 2
|
||||
filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=filter_bank)
|
||||
|
||||
# check expected default wavelet properties
|
||||
assert_(~haar_custom2.orthogonal)
|
||||
assert_(~haar_custom2.biorthogonal)
|
||||
assert_(haar_custom2.symmetry == 'unknown')
|
||||
assert_(haar_custom2.family_name == '')
|
||||
assert_(haar_custom2.short_family_name == '')
|
||||
assert_(haar_custom2.vanishing_moments_phi == 0)
|
||||
assert_(haar_custom2.vanishing_moments_psi == 0)
|
||||
|
||||
# Some properties can be set by the user
|
||||
haar_custom2.orthogonal = True
|
||||
haar_custom2.biorthogonal = True
|
||||
|
||||
|
||||
def test_wavefun_sym3():
|
||||
w = pywt.Wavelet('sym3')
|
||||
# sym3 is an orthogonal wavelet, so 3 outputs from wavefun
|
||||
phi, psi, x = w.wavefun(level=3)
|
||||
assert_(phi.size == 41)
|
||||
assert_(psi.size == 41)
|
||||
assert_(x.size == 41)
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, num=x.size))
|
||||
phi_expect = np.array([0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
|
||||
3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
|
||||
8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
|
||||
1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
|
||||
4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
|
||||
-7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
|
||||
-2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
|
||||
-2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
|
||||
8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
|
||||
1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
|
||||
-1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
|
||||
2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
|
||||
1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
psi_expect = np.array([0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
|
||||
4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
|
||||
9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
|
||||
6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
|
||||
-2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
|
||||
-5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
|
||||
2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
|
||||
1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
|
||||
-9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
|
||||
-1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
|
||||
1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
|
||||
-2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
|
||||
-1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
assert_allclose(phi, phi_expect)
|
||||
assert_allclose(psi, psi_expect)
|
||||
|
||||
|
||||
def test_wavefun_bior13():
|
||||
w = pywt.Wavelet('bior1.3')
|
||||
# bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
|
||||
for arr in [phi_d, psi_d, phi_r, psi_r]:
|
||||
assert_(arr.size == 40)
|
||||
|
||||
phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
|
||||
0.01367188, 0.00390625, -0.03515625, -0.12890625,
|
||||
-0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
|
||||
0.15234375, 0.37890625, 0.78515625, 0.99609375,
|
||||
1.08203125, 1.13671875, 1.13671875, 1.08203125,
|
||||
0.99609375, 0.78515625, 0.37890625, 0.15234375,
|
||||
0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
|
||||
-0.12890625, -0.03515625, 0.00390625, 0.01367188,
|
||||
0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
|
||||
phi_r_expect = np.zeros(x.size, dtype=np.float64)
|
||||
phi_r_expect[15:23] = 1
|
||||
|
||||
psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0.015625, -0.015625, -0.140625, -0.109375,
|
||||
-0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
|
||||
-0.625, -1.125, -1.21875, -1.03125, -0.28125,
|
||||
0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
psi_r_expect = np.zeros(x.size, dtype=np.float64)
|
||||
psi_r_expect[7:15] = -0.125
|
||||
psi_r_expect[15:19] = 1
|
||||
psi_r_expect[19:23] = -1
|
||||
psi_r_expect[23:31] = 0.125
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
|
||||
assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
|
||||
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_pickle(tmpdir):
|
||||
wavelet = pywt.Wavelet('sym4')
|
||||
filename = os.path.join(tmpdir, 'wav.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(wavelet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
wavelet2 = pickle.load(f)
|
||||
assert isinstance(wavelet2, pywt.Wavelet)
|
||||
assert wavelet2.name == wavelet.name
|
||||
244
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp.py
vendored
Normal file
244
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp.py
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_packet_structure():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
|
||||
|
||||
def test_traversing_wp_tree():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
# First level
|
||||
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
|
||||
7.778174593052, 10.606601717798]),
|
||||
rtol=1e-12)
|
||||
|
||||
# Second level
|
||||
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
|
||||
|
||||
# Third level
|
||||
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
|
||||
|
||||
|
||||
def test_acess_path():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp['a'].path == 'a')
|
||||
assert_(wp['aa'].path == 'aa')
|
||||
assert_(wp['aaa'].path == 'aaa')
|
||||
|
||||
# Maximum level reached:
|
||||
assert_raises(IndexError, lambda: wp['aaaa'].path)
|
||||
|
||||
# Wrong path
|
||||
assert_raises(ValueError, lambda: wp['ac'].path)
|
||||
|
||||
|
||||
def test_access_node_attributes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
|
||||
assert_(wp['ad'].path == 'ad')
|
||||
assert_(wp['ad'].node_name == 'd')
|
||||
assert_(wp['ad'].parent.path == 'a')
|
||||
assert_(wp['ad'].level == 2)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
assert_(wp['ad'].mode == 'symmetric')
|
||||
|
||||
# tuple-based access is also supported
|
||||
node = wp[('a', 'd')]
|
||||
# can access a node's path as either a single string or in tuple form
|
||||
assert_(node.path == 'ad')
|
||||
assert_(node.path_tuple == ('a', 'd'))
|
||||
|
||||
|
||||
def test_collecting_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# All nodes in natural order
|
||||
assert_([node.path for node in wp.get_level(3, 'natural')] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
# and in frequency order.
|
||||
assert_([node.path for node in wp.get_level(3, 'freq')] ==
|
||||
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
|
||||
|
||||
assert_raises(ValueError, wp.get_level, 3, 'invalid_order')
|
||||
|
||||
|
||||
def test_reconstructing_data():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# Create another Wavelet Packet and feed it with some data.
|
||||
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['aa'] = wp['aa'].data
|
||||
new_wp['ad'] = [-2., -2.]
|
||||
|
||||
# For convenience, :attr:`Node.data` gets automatically extracted
|
||||
# from the :class:`Node` object:
|
||||
new_wp['d'] = wp['d']
|
||||
|
||||
# Reconstruct data from aa, ad, and d packets.
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
# The node's :attr:`~Node.data` will not be updated
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
# When `update` is True:
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
|
||||
['aa', 'ad', 'd'])
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
|
||||
def test_removing_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
wp.get_level(2)
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
node = wp['ad']
|
||||
del(wp['ad'])
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(3):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
wp.reconstruct()
|
||||
# The reconstruction is:
|
||||
assert_allclose(wp.reconstruct(),
|
||||
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
|
||||
|
||||
# Restore the data
|
||||
wp['ad'].data = node.data
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
rstate = np.random.RandomState(0)
|
||||
N = 32
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = rstate.randn(N).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assigning to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# first element of the tuple is the input dtype
|
||||
# second element of the tuple is the transform dtype
|
||||
dtype_pairs = [(np.uint8, np.float64),
|
||||
(np.intp, np.float64), ]
|
||||
if hasattr(np, "complex256"):
|
||||
dtype_pairs += [(np.complex256, np.complex128), ]
|
||||
if hasattr(np, "half"):
|
||||
dtype_pairs += [(np.half, np.float32), ]
|
||||
for (dtype, transform_dtype) in dtype_pairs:
|
||||
x = np.arange(N, dtype=dtype)
|
||||
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# no unnecessary copy made of top-level data
|
||||
assert_(wp.data is x)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstructed data will have modified dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, transform_dtype)
|
||||
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_db3_roundtrip():
|
||||
original = np.arange(512)
|
||||
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_axis():
|
||||
rstate = np.random.RandomState(0)
|
||||
shape = (32, 16)
|
||||
x = rstate.standard_normal(shape)
|
||||
for axis in [0, 1, -1]:
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric',
|
||||
axis=axis)
|
||||
|
||||
# partial decomposition
|
||||
nodes = wp.get_level(2)
|
||||
# size along the transformed axis has changed
|
||||
for ax2 in range(x.ndim):
|
||||
if ax2 == (axis % x.ndim):
|
||||
nodes[0].data.shape[ax2] < x.shape[ax2]
|
||||
else:
|
||||
nodes[0].data.shape[ax2] == x.shape[ax2]
|
||||
|
||||
# recontsruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-12, rtol=1e-12)
|
||||
|
||||
# ValueError if axis is out of range
|
||||
assert_raises(ValueError, pywt.WaveletPacket, data=x, wavelet='db1',
|
||||
axis=x.ndim)
|
||||
|
||||
|
||||
def test_wavelet_packet_pickle(tmpdir):
|
||||
packet = pywt.WaveletPacket(np.arange(16), 'sym4')
|
||||
filename = os.path.join(tmpdir, 'wp.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(packet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
packet2 = pickle.load(f)
|
||||
assert isinstance(packet2, pywt.WaveletPacket)
|
||||
245
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp2d.py
vendored
Normal file
245
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp2d.py
vendored
Normal file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_traversing_tree_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(np.all(wp.data == x))
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
|
||||
assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)
|
||||
|
||||
assert_(wp['a']['a'].data is wp['aa'].data)
|
||||
assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)
|
||||
|
||||
assert_raises(IndexError, lambda: wp['aaaa'])
|
||||
assert_raises(ValueError, lambda: wp['f'])
|
||||
|
||||
|
||||
def test_accessing_node_attributes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
|
||||
assert_(wp['av'].path == 'av')
|
||||
assert_(wp['av'].node_name == 'v')
|
||||
assert_(wp['av'].parent.path == 'a')
|
||||
|
||||
assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
# can also index via a tuple instead of concatenated strings
|
||||
assert_(wp['av'].level == 2)
|
||||
assert_(wp['av'].maxlevel == 3)
|
||||
assert_(wp['av'].mode == 'symmetric')
|
||||
|
||||
# tuple-based access is also supported
|
||||
node = wp[('a', 'v')]
|
||||
# can access a node's path as either a single string or in tuple form
|
||||
assert_(node.path == 'av')
|
||||
assert_(node.path_tuple == ('a', 'v'))
|
||||
|
||||
|
||||
def test_collecting_nodes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(len(wp.get_level(0)) == 1)
|
||||
assert_(wp.get_level(0)[0].path == '')
|
||||
|
||||
# First level
|
||||
assert_(len(wp.get_level(1)) == 4)
|
||||
assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])
|
||||
|
||||
# Second level
|
||||
assert_(len(wp.get_level(2)) == 16)
|
||||
paths = [node.path for node in wp.get_level(2)]
|
||||
expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
|
||||
'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
# Third level.
|
||||
assert_(len(wp.get_level(3)) == 64)
|
||||
paths = [node.path for node in wp.get_level(3)]
|
||||
expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
|
||||
'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
|
||||
'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
|
||||
'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
|
||||
'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
|
||||
'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
|
||||
'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
|
||||
'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']
|
||||
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
# test 2D frequency ordering at the first level
|
||||
fnodes = wp.get_level(1, order='freq')
|
||||
assert_(fnodes[0][0].path == 'a')
|
||||
assert_(fnodes[0][1].path == 'v')
|
||||
assert_(fnodes[1][0].path == 'h')
|
||||
assert_(fnodes[1][1].path == 'd')
|
||||
|
||||
# test 2D frequency ordering at the second level
|
||||
fnodes = wp.get_level(2, order='freq')
|
||||
assert_([n.path for n in fnodes[0]] == ['aa', 'av', 'vv', 'va'])
|
||||
assert_([n.path for n in fnodes[1]] == ['ah', 'ad', 'vd', 'vh'])
|
||||
assert_([n.path for n in fnodes[2]] == ['hh', 'hd', 'dd', 'dh'])
|
||||
assert_([n.path for n in fnodes[3]] == ['ha', 'hv', 'dv', 'da'])
|
||||
|
||||
# invalid node collection order
|
||||
assert_raises(ValueError, wp.get_level, 2, 'invalid_order')
|
||||
|
||||
|
||||
def test_data_reconstruction_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['vh'] = wp['vh'].data
|
||||
new_wp['vv'] = wp['vh'].data
|
||||
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['h'] = wp['h'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)
|
||||
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
|
||||
def test_data_reconstruction_delete_nodes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['vh'] = wp['vh'].data
|
||||
new_wp['vv'] = wp['vh'].data
|
||||
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['h'] = wp['h'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
del(new_wp['va'])
|
||||
# TypeError on accessing deleted node
|
||||
assert_raises(TypeError, lambda: new_wp['va'])
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, x, rtol=1e-12)
|
||||
|
||||
# TODO: decompose=True
|
||||
|
||||
|
||||
def test_lazy_evaluation_2D():
|
||||
# Note: internal implementation detail not to be relied on. Testing for
|
||||
# now for backwards compatibility, but this test may be broken in needed.
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.a is None)
|
||||
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
|
||||
assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
shape = (16, 16)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(*shape).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assigning to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_2d_roundtrip():
|
||||
# test case corresponding to PyWavelets issue 447
|
||||
original = pywt.data.camera()
|
||||
wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_axes():
|
||||
rstate = np.random.RandomState(0)
|
||||
shape = (32, 16)
|
||||
x = rstate.standard_normal(shape)
|
||||
for axes in [(0, 1), (1, 0), (-2, 1)]:
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric',
|
||||
axes=axes)
|
||||
|
||||
# partial decomposition
|
||||
nodes = wp.get_level(2)
|
||||
# size along the transformed axes has changed
|
||||
for ax2 in range(x.ndim):
|
||||
if ax2 in tuple(np.asarray(axes) % x.ndim):
|
||||
nodes[0].data.shape[ax2] < x.shape[ax2]
|
||||
else:
|
||||
nodes[0].data.shape[ax2] == x.shape[ax2]
|
||||
|
||||
# recontsruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-12, rtol=1e-12)
|
||||
|
||||
# must have two non-duplicate axes
|
||||
assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
|
||||
axes=(0, 0))
|
||||
assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
|
||||
axes=(0, ))
|
||||
assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
|
||||
axes=(0, 1, 2))
|
||||
|
||||
|
||||
def test_wavelet_packet2d_pickle(tmpdir):
|
||||
packet = pywt.WaveletPacket2D(np.arange(256).reshape(16, 16), 'sym4')
|
||||
filename = os.path.join(tmpdir, 'wp2d.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(packet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
packet2 = pickle.load(f)
|
||||
assert isinstance(packet2, pywt.WaveletPacket2D)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user