using for loop to install conda package
This commit is contained in:
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test__pywt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test__pywt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_concurrent.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_concurrent.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_cwt_wavelets.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_cwt_wavelets.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_data.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_data.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_deprecations.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_deprecations.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_doc.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_doc.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_dwt_idwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_dwt_idwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_functions.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_functions.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility_cwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_matlab_compatibility_cwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_modes.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_modes.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_mra.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_mra.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multidim.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multidim.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multilevel.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_multilevel.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_perfect_reconstruction.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_perfect_reconstruction.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_swt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_swt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_thresholding.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_thresholding.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wavelet.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wavelet.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp2d.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wp2d.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wpnd.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/__pycache__/test_wpnd.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data_cwt.cpython-311.pyc
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/__pycache__/generate_matlab_data_cwt.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/cwt_matlabR2015b_result.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/cwt_matlabR2015b_result.npz
vendored
Normal file
Binary file not shown.
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/dwt_matlabR2012a_result.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/dwt_matlabR2012a_result.npz
vendored
Normal file
Binary file not shown.
96
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data.py
vendored
Normal file
96
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data.py
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw')]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
# Matlab result
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = ("[ma, md] = dwt(data, wavelet, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
|
||||
# Matlab result
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('dwt_matlabR2012a_result.npz', **all_matlab_results)
|
||||
86
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data_cwt.py
vendored
Normal file
86
.CondaPkg/env/Lib/site-packages/pywt/tests/data/generate_matlab_data_cwt.py
vendored
Normal file
@@ -0,0 +1,86 @@
|
||||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
Scales = (1,np.arange(1,3))
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
all_matlab_results[psi_key] = psi
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
|
||||
# Matlab result
|
||||
scale_count = 0
|
||||
for scales in Scales:
|
||||
scale_count += 1
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
coefs_key = '_'.join([str(scale_count), wavelet, str(N), 'coefs'])
|
||||
all_matlab_results[coefs_key] = coefs
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('cwt_matlabR2015b_result.npz', **all_matlab_results)
|
||||
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
vendored
Normal file
BIN
.CondaPkg/env/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
vendored
Normal file
Binary file not shown.
170
.CondaPkg/env/Lib/site-packages/pywt/tests/test__pywt.py
vendored
Normal file
170
.CondaPkg/env/Lib/site-packages/pywt/tests/test__pywt.py
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_upcoef_reconstruct():
|
||||
data = np.arange(3)
|
||||
a = pywt.downcoef('a', data, 'haar')
|
||||
d = pywt.downcoef('d', data, 'haar')
|
||||
|
||||
rec = (pywt.upcoef('a', a, 'haar', take=3) +
|
||||
pywt.upcoef('d', d, 'haar', take=3))
|
||||
assert_allclose(rec, data)
|
||||
|
||||
|
||||
def test_downcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.downcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_downcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16) + 1j * rstate.randn(16)
|
||||
nlevels = 3
|
||||
a = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_downcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar')
|
||||
|
||||
|
||||
def test_compare_downcoef_coeffs():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
# compare downcoef against wavedec outputs
|
||||
for nlevels in [1, 2, 3]:
|
||||
for wavelet in pywt.wavelist():
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
||||
if isinstance(wavelet, pywt.Wavelet):
|
||||
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
|
||||
if nlevels <= max_level:
|
||||
a = pywt.downcoef('a', r, wavelet, level=nlevels)
|
||||
d = pywt.downcoef('d', r, wavelet, level=nlevels)
|
||||
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
|
||||
assert_allclose(a, coeffs[0])
|
||||
assert_allclose(d, coeffs[1])
|
||||
|
||||
|
||||
def test_upcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.upcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_upcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4) + 1j*rstate.randn(4)
|
||||
nlevels = 3
|
||||
a = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_upcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar')
|
||||
|
||||
|
||||
def test_upcoef_and_downcoef_1d_only():
|
||||
# upcoef and downcoef raise a ValueError if data.ndim > 1d
|
||||
for ndim in [2, 3]:
|
||||
data = np.ones((8, )*ndim)
|
||||
assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar')
|
||||
assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar')
|
||||
|
||||
|
||||
def test_wavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.Wavelet('sym8')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_dwt_max_level():
|
||||
assert_(pywt.dwt_max_level(16, 2) == 4)
|
||||
assert_(pywt.dwt_max_level(16, 8) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 9) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 10) == 0)
|
||||
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 10.) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 18) == 0)
|
||||
|
||||
# accepts discrete Wavelet object or string as well
|
||||
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
|
||||
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
|
||||
|
||||
# string input that is not a discrete wavelet
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
|
||||
|
||||
# filter_len must be an integer >= 2
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
|
||||
|
||||
|
||||
def test_ContinuousWavelet_errs():
|
||||
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
|
||||
|
||||
|
||||
def test_ContinuousWavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.ContinuousWavelet('gaus2')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_wavelist():
|
||||
for name in pywt.wavelist(family='coif'):
|
||||
assert_(name.startswith('coif'))
|
||||
|
||||
assert_('cgau7' in pywt.wavelist(kind='continuous'))
|
||||
assert_('sym20' in pywt.wavelist(kind='discrete'))
|
||||
assert_(len(pywt.wavelist(kind='continuous')) +
|
||||
len(pywt.wavelist(kind='discrete')) ==
|
||||
len(pywt.wavelist(kind='all')))
|
||||
|
||||
assert_raises(ValueError, pywt.wavelist, kind='foobar')
|
||||
|
||||
|
||||
def test_wavelet_errormsgs():
|
||||
try:
|
||||
pywt.Wavelet('gaus1')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0].startswith('The `Wavelet` class'))
|
||||
try:
|
||||
pywt.Wavelet('cmord')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")
|
||||
105
.CondaPkg/env/Lib/site-packages/pywt/tests/test_concurrent.py
vendored
Normal file
105
.CondaPkg/env/Lib/site-packages/pywt/tests/test_concurrent.py
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Tests used to verify running PyWavelets transforms in parallel via
|
||||
concurrent.futures.ThreadPoolExecutor does not raise errors.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from numpy.testing import assert_array_equal, assert_allclose
|
||||
from pywt._pytest import uses_futures, futures, max_workers
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def _assert_all_coeffs_equal(coefs1, coefs2):
|
||||
# return True only if all coefficients of SWT or DWT match over all levels
|
||||
if len(coefs1) != len(coefs2):
|
||||
return False
|
||||
for (c1, c2) in zip(coefs1, coefs2):
|
||||
if isinstance(c1, tuple):
|
||||
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
|
||||
for a1, a2 in zip(c1, c2):
|
||||
assert_array_equal(a1, a2)
|
||||
elif isinstance(c1, dict):
|
||||
# for swtn, dwtn, wavedecn
|
||||
for k, v in c1.items():
|
||||
assert_array_equal(v, c2[k])
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_swt():
|
||||
# tests error-free concurrent operation (see gh-288)
|
||||
# swt on 1D data calls the Cython swt
|
||||
# other cases call swt_axes
|
||||
with warnings.catch_warnings():
|
||||
# can remove catch_warnings once the swt2 FutureWarning is removed
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(swt_func, wavelet='haar', level=3)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_wavedec():
|
||||
# wavedec on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(wavedec_func, wavelet='haar', level=1)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_dwt():
|
||||
# dwt on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(dwt_func, wavelet='haar')
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_cwt():
|
||||
atol = rtol = 1e-14
|
||||
time, sst = pywt.data.nino()
|
||||
dt = time[1]-time[0]
|
||||
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
|
||||
sampling_period=dt)
|
||||
for _ in range(10):
|
||||
arrs = [sst.copy() for _ in range(50)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(sst)
|
||||
for a1, a2 in zip(expected_result, results[-1]):
|
||||
assert_allclose(a1, a2, atol=atol, rtol=rtol)
|
||||
462
.CondaPkg/env/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
vendored
Normal file
462
.CondaPkg/env/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
from itertools import product
|
||||
import pickle
|
||||
|
||||
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
|
||||
assert_raises, assert_equal)
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
|
||||
def ref_gaus(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
|
||||
if (num == 1):
|
||||
psi = -2.*X*F0
|
||||
elif (num == 2):
|
||||
psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
|
||||
elif (num == 3):
|
||||
psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
|
||||
elif (num == 4):
|
||||
psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
|
||||
elif (num == 5):
|
||||
psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
|
||||
elif (num == 6):
|
||||
psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
|
||||
elif (num == 7):
|
||||
psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
|
||||
elif (num == 8):
|
||||
psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
|
||||
224*X**6 + 16*X**8)*F0
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def ref_cgau(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = np.exp(-X**2)
|
||||
F1 = np.exp(-1j*X)
|
||||
F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
|
||||
if (num == 1):
|
||||
psi = F2*(-1j - 2*X)*2**(1/2)
|
||||
elif (num == 2):
|
||||
psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
|
||||
elif (num == 3):
|
||||
psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
|
||||
elif (num == 4):
|
||||
psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
|
||||
elif (num == 5):
|
||||
psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
|
||||
80j*X**4 - 32*X**5)*210**(1/2)
|
||||
elif (num == 6):
|
||||
psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
|
||||
192j*X**5 + 64*X**6)*2310**(1/2)
|
||||
elif (num == 7):
|
||||
psi = 1/45045*F2*(
|
||||
1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
|
||||
448j*X**6 - 128*X**7)*30030**(1/2)
|
||||
elif (num == 8):
|
||||
psi = 1/45045*F2*(
|
||||
5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
|
||||
12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
|
||||
|
||||
psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def sinc2(x):
|
||||
y = np.ones_like(x)
|
||||
k = np.where(x)[0]
|
||||
y[k] = np.sin(np.pi*x[k])/(np.pi*x[k])
|
||||
return y
|
||||
|
||||
|
||||
def ref_shan(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*(sinc2(Fb*x)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_fbsp(LB, UB, N, m, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*((sinc2(Fb*x/m)**m)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_cmor(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = ((np.pi*Fb)**(-0.5))*np.exp(2j*np.pi*Fc*x)*np.exp(-(x**2)/Fb)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_morl(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.exp(-(x**2)/2)*np.cos(5*x)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_mexh(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = (2/(np.sqrt(3)*np.pi**0.25))*np.exp(-(x**2)/2)*(1 - (x**2))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def test_gaus():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_gaus(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("gaus" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_continuous_wavelet_dtype(dtype):
|
||||
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0', dtype)
|
||||
int_psi, x = pywt.integrate_wavelet(wavelet)
|
||||
assert int_psi.real.dtype == dtype
|
||||
assert x.dtype == dtype
|
||||
|
||||
|
||||
def test_continuous_wavelet_invalid_dtype():
|
||||
with pytest.raises(ValueError):
|
||||
pywt.ContinuousWavelet('gaus5', np.complex64)
|
||||
with pytest.raises(ValueError):
|
||||
pywt.ContinuousWavelet('gaus5', np.int_)
|
||||
|
||||
|
||||
def test_cgau():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_cgau(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("cgau" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_shan():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_cmor():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_fbsp():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 3
|
||||
Fb = 1.5
|
||||
Fc = 1.2
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
# TODO: investigate why atol = 1e-5 is necessary
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-5)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-5)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_morl():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_morl(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("morl")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_mexh():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1001
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_cwt_parameters_in_names():
|
||||
|
||||
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
|
||||
for name in ['fbsp', 'cmor', 'shan']:
|
||||
# additional parameters should be specified within the name
|
||||
assert_warns(FutureWarning, func, name)
|
||||
|
||||
for name in ['cmor', 'shan']:
|
||||
# valid names
|
||||
func(name + '1.5-1.0')
|
||||
func(name + '1-4')
|
||||
|
||||
# invalid names
|
||||
assert_raises(ValueError, func, name + '1.0')
|
||||
assert_raises(ValueError, func, name + 'B-C')
|
||||
assert_raises(ValueError, func, name + '1.0-1.0-1.0')
|
||||
|
||||
# valid names
|
||||
func('fbsp1-1.5-1.0')
|
||||
func('fbsp1.0-1.5-1')
|
||||
func('fbsp2-5-1')
|
||||
|
||||
# invalid name (non-integer order)
|
||||
assert_raises(ValueError, func, 'fbsp1.5-1-1')
|
||||
assert_raises(ValueError, func, 'fbspM-B-C')
|
||||
|
||||
# invalid name (too few or too many params)
|
||||
assert_raises(ValueError, func, 'fbsp1.0')
|
||||
assert_raises(ValueError, func, 'fbsp1.0-0.4')
|
||||
assert_raises(ValueError, func, 'fbsp1-1-1-1')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype, tol, method',
|
||||
[(np.float32, 1e-5, 'conv'),
|
||||
(np.float32, 1e-5, 'fft'),
|
||||
(np.float64, 1e-13, 'conv'),
|
||||
(np.float64, 1e-13, 'fft')])
|
||||
def test_cwt_complex(dtype, tol, method):
|
||||
time, sst = pywt.data.nino()
|
||||
sst = np.asarray(sst, dtype=dtype)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# real-valued tranfsorm as a reference
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# complex-valued transform equals sum of the transforms of the real
|
||||
# and imaginary components
|
||||
sst_complex = sst + 1j*sst
|
||||
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
|
||||
method=method)
|
||||
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
|
||||
# verify dtype is preserved
|
||||
assert_equal(cfs_complex.dtype, sst_complex.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
|
||||
def test_cwt_batch(axis, method):
|
||||
dtype = np.float64
|
||||
time, sst = pywt.data.nino()
|
||||
n_batch = 8
|
||||
batch_axis = 1 - axis
|
||||
sst1 = np.asarray(sst, dtype=dtype)
|
||||
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# non-batch transform as reference
|
||||
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
shape_in = sst.shape
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
# shape of input is not modified
|
||||
assert_equal(shape_in, sst.shape)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# verify expected shape
|
||||
assert_equal(cfs.shape[0], len(scales))
|
||||
assert_equal(cfs.shape[1 + batch_axis], n_batch)
|
||||
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
|
||||
|
||||
# batch result on stacked input is the same as stacked 1d result
|
||||
assert_almost_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1),
|
||||
decimal=12)
|
||||
|
||||
|
||||
def test_cwt_small_scales():
|
||||
data = np.zeros(32)
|
||||
|
||||
# A scale of 0.1 was chosen specifically to give a filter of length 2 for
|
||||
# mexh. This corner case should not raise an error.
|
||||
cfs, f = pywt.cwt(data, scales=0.1, wavelet='mexh')
|
||||
assert_allclose(cfs, np.zeros_like(cfs))
|
||||
|
||||
# extremely short scale factors raise a ValueError
|
||||
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
|
||||
|
||||
|
||||
def test_cwt_method_fft():
|
||||
rstate = np.random.RandomState(1)
|
||||
data = rstate.randn(50)
|
||||
data[15] = 1.
|
||||
scales = np.arange(1, 64)
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
|
||||
# build a reference cwt with the legacy np.conv() method
|
||||
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
|
||||
|
||||
# compare with the fft based convolution
|
||||
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
|
||||
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)
|
||||
|
||||
|
||||
def test_continuous_wavelet_pickle(tmpdir):
|
||||
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0')
|
||||
filename = os.path.join(tmpdir, 'cwav.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(wavelet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
wavelet2 = pickle.load(f)
|
||||
assert isinstance(wavelet2, pywt.ContinuousWavelet)
|
||||
assert wavelet2.name == wavelet.name
|
||||
77
.CondaPkg/env/Lib/site-packages/pywt/tests/test_data.py
vendored
Normal file
77
.CondaPkg/env/Lib/site-packages/pywt/tests/test_data.py
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_
|
||||
|
||||
import pywt.data
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
||||
wavelab_data_file = os.path.join(data_dir, 'wavelab_test_signals.npz')
|
||||
wavelab_result_dict = np.load(wavelab_data_file)
|
||||
|
||||
|
||||
def test_data_aero():
|
||||
aero = pywt.data.aero()
|
||||
|
||||
ref = np.array([[178, 178, 179],
|
||||
[170, 173, 171],
|
||||
[185, 174, 171]])
|
||||
|
||||
assert_allclose(aero[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ascent():
|
||||
ascent = pywt.data.ascent()
|
||||
|
||||
ref = np.array([[83, 83, 83],
|
||||
[82, 82, 83],
|
||||
[80, 81, 83]])
|
||||
|
||||
assert_allclose(ascent[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_camera():
|
||||
camera = pywt.data.camera()
|
||||
|
||||
ref = np.array([[200, 200, 200],
|
||||
[200, 199, 199],
|
||||
[199, 199, 199]])
|
||||
|
||||
assert_allclose(camera[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ecg():
|
||||
ecg = pywt.data.ecg()
|
||||
|
||||
ref = np.array([-86, -87, -87])
|
||||
|
||||
assert_allclose(ecg[:3], ref)
|
||||
|
||||
|
||||
def test_wavelab_signals():
|
||||
"""Comparison with results generated using WaveLab"""
|
||||
rtol = atol = 1e-12
|
||||
|
||||
# get a list of the available signals
|
||||
available_signals = pywt.data.demo_signal('list')
|
||||
assert_('Doppler' in available_signals)
|
||||
|
||||
for signal in available_signals:
|
||||
# reference dictionary has lowercase names for the keys
|
||||
key = signal.replace('-', '_').lower()
|
||||
val = wavelab_result_dict[key]
|
||||
if key in ['gabor', 'sineoneoverx']:
|
||||
# these functions do not allow a size to be provided
|
||||
assert_allclose(val, pywt.data.demo_signal(signal),
|
||||
rtol=rtol, atol=atol)
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key, val.size)
|
||||
else:
|
||||
assert_allclose(val, pywt.data.demo_signal(signal, val.size),
|
||||
rtol=rtol, atol=atol)
|
||||
# these functions require a size to be provided
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key)
|
||||
|
||||
# ValueError on unrecognized signal type
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'unknown_signal', 512)
|
||||
|
||||
# ValueError on invalid length
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)
|
||||
89
.CondaPkg/env/Lib/site-packages/pywt/tests/test_deprecations.py
vendored
Normal file
89
.CondaPkg/env/Lib/site-packages/pywt/tests/test_deprecations.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_warns, assert_array_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_intwave_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
|
||||
|
||||
|
||||
def test_centrfrq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
|
||||
|
||||
|
||||
def test_scal2frq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
|
||||
|
||||
|
||||
def test_orthfilt_deprecation():
|
||||
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
|
||||
|
||||
|
||||
def test_integrate_wave_tuple():
|
||||
sig = [0, 1, 2, 3]
|
||||
xgrid = [0, 1, 2, 3]
|
||||
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
|
||||
|
||||
|
||||
old_modes = ['zpd',
|
||||
'cpd',
|
||||
'sym',
|
||||
'ppd',
|
||||
'sp1',
|
||||
'per',
|
||||
]
|
||||
|
||||
|
||||
def test_MODES_from_object_deprecation():
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
|
||||
|
||||
|
||||
def test_MODES_attributes_deprecation():
|
||||
def get_mode(Modes, name):
|
||||
return getattr(Modes, name)
|
||||
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
|
||||
|
||||
|
||||
def test_MODES_deprecation_new():
|
||||
def use_MODES_new():
|
||||
return pywt.MODES.symmetric
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_MODES_deprecation_old():
|
||||
def use_MODES_old():
|
||||
return pywt.MODES.sym
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_old)
|
||||
|
||||
|
||||
def test_MODES_deprecation_getattr():
|
||||
def use_MODES_new():
|
||||
return getattr(pywt.MODES, 'symmetric')
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_mode_equivalence():
|
||||
old_new = [('zpd', 'zero'),
|
||||
('cpd', 'constant'),
|
||||
('sym', 'symmetric'),
|
||||
('ppd', 'periodic'),
|
||||
('sp1', 'smooth'),
|
||||
('per', 'periodization')]
|
||||
x = np.arange(8.)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
for old, new in old_new:
|
||||
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
|
||||
pywt.dwt(x, 'db2', mode=new))
|
||||
25
.CondaPkg/env/Lib/site-packages/pywt/tests/test_doc.py
vendored
Normal file
25
.CondaPkg/env/Lib/site-packages/pywt/tests/test_doc.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import doctest
|
||||
import glob
|
||||
import os
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
np.set_printoptions(legacy='1.13')
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
pdir = os.path.pardir
|
||||
docs_base = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
pdir, pdir, "doc", "source"))
|
||||
|
||||
files = glob.glob(os.path.join(docs_base, "*.rst")) + \
|
||||
glob.glob(os.path.join(docs_base, "*", "*.rst"))
|
||||
|
||||
suite = doctest.DocFileSuite(*files, module_relative=False, encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.TextTestRunner().run(suite)
|
||||
307
.CondaPkg/env/Lib/site-packages/pywt/tests/test_dwt_idwt.py
vendored
Normal file
307
.CondaPkg/env/Lib/site-packages/pywt/tests/test_dwt_idwt.py
vendored
Normal file
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_array_equal)
|
||||
import pywt
|
||||
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwt_idwt_basic():
|
||||
x = [3, 7, 1, 1, -2, 5, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
|
||||
cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.float64)
|
||||
|
||||
|
||||
def test_idwt_mixed_complex_dtype():
|
||||
x = np.arange(8).astype(float)
|
||||
x = x + 1j*x[::-1]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_dwt_idwt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones(4, dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet)
|
||||
assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)
|
||||
|
||||
|
||||
def test_dwt_idwt_basic_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
|
||||
7.77817459])
|
||||
cA_expect = cA_expect + 0.5j*cA_expect
|
||||
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487])
|
||||
cD_expect = cD_expect + 0.5j*cD_expect
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_idwt_partial_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
|
||||
cA, cD = pywt.dwt(x, 'haar')
|
||||
cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
|
||||
1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
|
||||
cA_rec = pywt.idwt(cA, None, 'haar')
|
||||
assert_allclose(cA_rec, cA_rec_expect)
|
||||
|
||||
cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
|
||||
-3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
|
||||
cD_rec = pywt.idwt(None, cD, 'haar')
|
||||
assert_allclose(cD_rec, cD_rec_expect)
|
||||
|
||||
assert_allclose(cA_rec + cD_rec, x)
|
||||
|
||||
|
||||
def test_dwt_wavelet_kwd():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
|
||||
cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
|
||||
7.81994027]
|
||||
cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
|
||||
-0.09722957, -0.07045258]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
|
||||
def test_dwt_coeff_len():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
|
||||
expected_result = [6, ] * len(pywt.Modes.modes)
|
||||
expected_result[pywt.Modes.modes.index('periodization')] = 4
|
||||
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
|
||||
|
||||
def test_idwt_none_input():
|
||||
# None input equals arrays of zeros of the right length
|
||||
res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
# Only one argument at a time can be None
|
||||
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')
|
||||
|
||||
|
||||
def test_idwt_invalid_input():
|
||||
# Too short, min length is 4 for 'db4':
|
||||
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
|
||||
|
||||
|
||||
def test_dwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
cA0, cD0 = pywt.dwt(x[0], 'db2')
|
||||
cA1, cD1 = pywt.dwt(x[1], 'db2')
|
||||
|
||||
assert_allclose(cA[0], cA0)
|
||||
assert_allclose(cA[1], cA1)
|
||||
|
||||
assert_allclose(cD[0], cD0)
|
||||
assert_allclose(cD[1], cD1)
|
||||
|
||||
|
||||
def test_idwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
x = np.asarray(x)
|
||||
x = x + 1j*x # test with complex data
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
|
||||
x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)
|
||||
|
||||
assert_allclose(x[0], x0)
|
||||
assert_allclose(x[1], x1)
|
||||
|
||||
def test_dwt_invalid_input():
|
||||
x = np.arange(1)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'reflect')
|
||||
assert_raises(ValueError, pywt.dwt, x, 'haar', 'antireflect')
|
||||
|
||||
|
||||
def test_dwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
assert_allclose(cA_, cA)
|
||||
assert_allclose(cD_, cD)
|
||||
|
||||
def test_dwt_axis_invalid_input():
|
||||
x = np.ones((3,1))
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'reflect')
|
||||
|
||||
def test_idwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
x_ = pywt.idwt(cA, cD, 'db2', axis=-1)
|
||||
x = pywt.idwt(cA, cD, 'db2', axis=1)
|
||||
|
||||
assert_allclose(x_, x)
|
||||
|
||||
|
||||
def test_dwt_idwt_axis_excess():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
# can't transform over axes that aren't there
|
||||
assert_raises(ValueError,
|
||||
pywt.dwt, x, 'db2', 'symmetric', axis=2)
|
||||
|
||||
assert_raises(ValueError,
|
||||
pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((32, ))
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, pywt.dwt, data, cwave)
|
||||
|
||||
cA, cD = pywt.dwt(data, 'db1')
|
||||
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
|
||||
|
||||
|
||||
def test_dwt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.dwt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
|
||||
|
||||
|
||||
def test_pad_1d():
|
||||
x = [1, 2, 3]
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
|
||||
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
|
||||
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
|
||||
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
|
||||
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
|
||||
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
|
||||
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
|
||||
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
# equivalence of various pad_width formats
|
||||
assert_array_equal(pywt.pad(x, 4, 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
|
||||
def test_pad_errors():
|
||||
# negative pad width
|
||||
x = [1, 2, 3]
|
||||
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
|
||||
|
||||
# wrong length pad width
|
||||
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
|
||||
|
||||
# invalid mode name
|
||||
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
|
||||
|
||||
|
||||
def test_pad_nd():
|
||||
for ndim in [2, 3]:
|
||||
x = np.arange(4**ndim).reshape((4, ) * ndim)
|
||||
if ndim == 2:
|
||||
pad_widths = [(2, 1), (2, 3)]
|
||||
else:
|
||||
pad_widths = [(2, 1), ] * ndim
|
||||
for mode in pywt.Modes.modes:
|
||||
xp = pywt.pad(x, pad_widths, mode)
|
||||
|
||||
# expected result is the same as applying along axes separably
|
||||
xp_expected = x.copy()
|
||||
for ax in range(ndim):
|
||||
xp_expected = np.apply_along_axis(pywt.pad,
|
||||
ax,
|
||||
xp_expected,
|
||||
pad_widths=[pad_widths[ax]],
|
||||
mode=mode)
|
||||
assert_array_equal(xp, xp_expected)
|
||||
45
.CondaPkg/env/Lib/site-packages/pywt/tests/test_functions.py
vendored
Normal file
45
.CondaPkg/env/Lib/site-packages/pywt/tests/test_functions.py
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from numpy.testing import assert_almost_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_centrfreq():
|
||||
# db1 is Haar function, frequency=1
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
# db2, frequency=2/3
|
||||
w = pywt.Wavelet('db2')
|
||||
expected = 2/3.
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected)
|
||||
|
||||
|
||||
def test_scal2frq_scale():
|
||||
scale = 2
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1. / scale
|
||||
result = pywt.scale2frequency(w, scale, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
|
||||
def test_frq2scal_freq():
|
||||
freq = 2
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1. / freq
|
||||
result = pywt.frequency2scale(w, freq, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
|
||||
|
||||
def test_intwave_orthogonal():
|
||||
w = pywt.Wavelet('db1')
|
||||
int_psi, x = pywt.integrate_wavelet(w, precision=12)
|
||||
ix = x < 0.5
|
||||
# For x < 0.5, the integral is equal to x
|
||||
assert_allclose(int_psi[ix], x[ix])
|
||||
# For x > 0.5, the integral is equal to (1 - x)
|
||||
# Ignore last point here, there x > 1 and something goes wrong
|
||||
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)
|
||||
160
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
vendored
Normal file
160
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Test used to verify PyWavelets Discrete Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
|
||||
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
# TODO: Now have implemented asymmetric modes too.
|
||||
# Would be nice to update the Matlab data to test these as well.
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw'),
|
||||
]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
@uses_pymatbridge
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _load_matlab_result(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, mmode, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
pa, pd = pywt.dwt(data, w, pmode)
|
||||
|
||||
# calculate error measures
|
||||
rms_a = np.sqrt(np.mean((pa - ma) ** 2))
|
||||
rms_d = np.sqrt(np.mean((pd - md) ** 2))
|
||||
|
||||
msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a))
|
||||
assert_(rms_a < epsilon, msg=msg)
|
||||
|
||||
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
|
||||
assert_(rms_d < epsilon, msg=msg)
|
||||
174
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility_cwt.py
vendored
Normal file
174
.CondaPkg/env/Lib/site-packages/pywt/tests/test_matlab_compatibility_cwt.py
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Test used to verify PyWavelets Continuous Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
|
||||
matlab_result_dict_cwt)
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
def _get_scales(w):
|
||||
""" Return the scales to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
|
||||
else:
|
||||
scales = (1, np.arange(1, 3))
|
||||
return scales
|
||||
|
||||
|
||||
@uses_pymatbridge # skip this case if precomputed results are used instead
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge_cwt():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 1e-15
|
||||
epsilon_psi = 1e-15
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for scales in _get_scales(w):
|
||||
coefs = _compute_matlab_result(data, wavelet, scales, mlab)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed # skip this case if pymatbridge + Matlab are being used
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed_cwt():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
# has to be improved
|
||||
epsilon = 2e-15
|
||||
epsilon32 = 1e-5
|
||||
epsilon_psi = 1e-15
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
|
||||
psi = _load_matlab_result_psi(wavelet)
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
data32 = data.astype(np.float32)
|
||||
scales_count = 0
|
||||
for scales in _get_scales(w):
|
||||
scales_count += 1
|
||||
coefs = _load_matlab_result(data, wavelet, scales_count)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
_check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, scales, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, scales):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
|
||||
if (coefs_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
|
||||
coefs = matlab_result_dict_cwt[coefs_key]
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result_psi(wavelet):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
if (psi_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab psi result not found for wavelet: "
|
||||
"{0}}".format(wavelet))
|
||||
psi = matlab_result_dict_cwt[psi_key]
|
||||
return psi
|
||||
|
||||
|
||||
def _check_accuracy(data, w, scales, coefs, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
coefs_pywt, freq = pywt.cwt(data, scales, w)
|
||||
|
||||
# coefs from Matlab are from R2012a which is missing the complex conjugate
|
||||
# as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of
|
||||
# the precomputed Matlab result to account for this.
|
||||
coefs = np.conj(coefs)
|
||||
|
||||
# calculate error measures
|
||||
err = coefs_pywt - coefs
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Scale: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (scales, wavelet, len(data), rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
||||
|
||||
|
||||
def _check_accuracy_psi(w, psi, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
psi_pywt, x = w.wavefun(length=1024)
|
||||
|
||||
# calculate error measures
|
||||
err = psi_pywt.flatten() - psi.flatten()
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Wavelet: %s, '
|
||||
'rms=%.3g' % (wavelet, rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
||||
109
.CondaPkg/env/Lib/site-packages/pywt/tests/test_modes.py
vendored
Normal file
109
.CondaPkg/env/Lib/site-packages/pywt/tests/test_modes.py
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_raises, assert_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_available_modes():
|
||||
modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth',
|
||||
'periodization', 'reflect', 'antisymmetric', 'antireflect']
|
||||
assert_equal(pywt.Modes.modes, modes)
|
||||
assert_equal(pywt.Modes.from_object('constant'), 2)
|
||||
|
||||
|
||||
def test_invalid_modes():
|
||||
x = np.arange(4)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', -1)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 9)
|
||||
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
|
||||
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 'unknown')
|
||||
assert_raises(ValueError, pywt.Modes.from_object, -1)
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 9)
|
||||
assert_raises(TypeError, pywt.Modes.from_object, None)
|
||||
|
||||
|
||||
def test_dwt_idwt_allmodes():
|
||||
# Test that :func:`dwt` and :func:`idwt` can be performed using every mode
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
dwt_results = {
|
||||
'zero': ([-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
|
||||
[-0.12940952, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.8625013]),
|
||||
'constant': ([1.28480404, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.51935555],
|
||||
[-0.48296291, -2.15599552, -5.95034847, -1.21545369,
|
||||
0.25881905]),
|
||||
'symmetric': ([1.76776695, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.77817459],
|
||||
[-0.61237244, -2.15599552, -5.95034847, -1.21545369,
|
||||
1.22474487]),
|
||||
'reflect': ([2.12132034, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.81224877],
|
||||
[-0.70710678, -2.15599552, -5.95034847, -1.21545369,
|
||||
-2.38013939]),
|
||||
'periodic': ([6.9162743, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.9162743],
|
||||
[-1.99191082, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.99191082]),
|
||||
'smooth': ([-0.51763809, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.45000519],
|
||||
[0, -2.15599552, -5.95034847, -1.21545369, 0]),
|
||||
'periodization': ([4.053172, 3.05257099, 2.85381112, 8.42522221],
|
||||
[0.18946869, 4.18258152, 4.33737503, 2.60428326]),
|
||||
'antisymmetric': ([-1.83711731, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.12372436],
|
||||
[0.353553391, -2.15599552, -5.95034847, -1.21545369,
|
||||
-4.94974747]),
|
||||
'antireflect': ([0.44828774, 1.73309178, 3.40612438, 6.32928585,
|
||||
8.22646233],
|
||||
[-0.25881905, -2.15599552, -5.95034847, -1.21545369,
|
||||
2.89777748])
|
||||
}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
cA, cD = pywt.dwt(x, 'db2', mode)
|
||||
assert_allclose(cA, dwt_results[mode][0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, dwt_results[mode][1], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2', mode), x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_short_input_allmodes():
|
||||
# some test cases where the input is shorter than the DWT filter
|
||||
x = [1, 3, 2]
|
||||
wavelet = 'db2'
|
||||
# manually pad each end by the filter size (4 for 'db2' used here)
|
||||
padded_x = {'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
|
||||
'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
|
||||
'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
|
||||
'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
|
||||
'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
|
||||
'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
|
||||
'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
|
||||
'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
|
||||
}
|
||||
for mode, xpad in padded_x.items():
|
||||
# DWT of the manually padded array. will discard edges later so
|
||||
# symmetric mode used here doesn't matter.
|
||||
cApad, cDpad = pywt.dwt(xpad, wavelet, mode='symmetric')
|
||||
|
||||
# central region of the padded output (unaffected by mode )
|
||||
expected_result = (cApad[2:-2], cDpad[2:-2])
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet, mode)
|
||||
assert_allclose(cA, expected_result[0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, expected_result[1], rtol=1e-7, atol=1e-8)
|
||||
|
||||
|
||||
def test_default_mode():
|
||||
# The default mode should be 'symmetric'
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA2, cD2 = pywt.dwt(x, 'db2', mode='symmetric')
|
||||
assert_allclose(cA, cA2)
|
||||
assert_allclose(cD, cD2)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)
|
||||
255
.CondaPkg/env/Lib/site-packages/pywt/tests/test_mra.py
vendored
Normal file
255
.CondaPkg/env/Lib/site-packages/pywt/tests/test_mra.py
vendored
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
import pywt
|
||||
from pywt import data
|
||||
|
||||
# tolerances used in accuracy comparisons
|
||||
tol_single = 1e-6
|
||||
tol_double = 1e-13
|
||||
atol = 1e-7
|
||||
|
||||
|
||||
####
|
||||
# 1d mra tests
|
||||
####
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
|
||||
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
||||
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
||||
@pytest.mark.parametrize(
|
||||
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
||||
)
|
||||
def test_mra_roundtrip(wavelet, transform, mode, dtype):
|
||||
x = data.ecg()[:64].astype(dtype)
|
||||
if x.dtype.kind == 'c':
|
||||
# fill some data for the imaginary channel
|
||||
x.imag = x[::-1].real
|
||||
|
||||
if transform == 'swt':
|
||||
# swt mode only supports periodization
|
||||
if mode != 'periodization':
|
||||
with pytest.raises(ValueError):
|
||||
pywt.mra(x, wavelet, transform=transform, mode=mode)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra(x, wavelet, transform=transform, mode=mode)
|
||||
assert isinstance(coeffs, list)
|
||||
assert isinstance(coeffs[0], np.ndarray)
|
||||
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
||||
|
||||
y = pywt.imra(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
||||
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
||||
def test_mra_warns_on_non_orthogonal(wavelet, transform):
|
||||
dtype = np.float64
|
||||
x = data.ecg()[:64].astype(dtype)
|
||||
|
||||
assert not pywt.Wavelet(wavelet).orthogonal
|
||||
|
||||
if transform == 'swt':
|
||||
# bi-orthogonal wavelets raise a warning for SWT case
|
||||
msg = 'norm=True, but the wavelet is not orthogonal'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
coeffs = pywt.mra(x, wavelet, transform=transform)
|
||||
else:
|
||||
coeffs = pywt.mra(x, wavelet, transform=transform)
|
||||
|
||||
y = pywt.imra(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('axis', [0, -1, 1, 2, -3])
|
||||
@pytest.mark.parametrize('ndim', [1, 2, 3])
|
||||
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
||||
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
|
||||
def test_mra_axis(transform, ndim, axis, dtype):
|
||||
# Test transforms over a specific axis of 1D, 2D or 3D data
|
||||
if ndim == 1:
|
||||
x = data.ecg()[:64]
|
||||
elif ndim == 2:
|
||||
x = data.camera()[:64, :32]
|
||||
elif ndim == 3:
|
||||
x = data.camera()[:48, :8]
|
||||
x = np.stack((x,) * 8, axis=-1)
|
||||
x = x.astype(dtype, copy=False)
|
||||
|
||||
# out of range axis
|
||||
if axis < -x.ndim or axis >= x.ndim:
|
||||
with pytest.raises(np.AxisError):
|
||||
pywt.mra(x, 'db1', transform=transform, axis=axis)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra(x, 'db1', transform=transform, axis=axis)
|
||||
y = pywt.imra(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
####
|
||||
# 2d mra tests
|
||||
####
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
|
||||
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
||||
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
||||
@pytest.mark.parametrize(
|
||||
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
||||
)
|
||||
def test_mra2_roundtrip(wavelet, transform, mode, dtype):
|
||||
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
||||
if x.dtype.kind == 'c':
|
||||
# fill some data for the imaginary channel
|
||||
x.imag = x[::-1, :].real
|
||||
|
||||
if transform == 'swt2':
|
||||
# swt mode only supports periodization
|
||||
if mode != 'periodization':
|
||||
with pytest.raises(ValueError):
|
||||
pywt.mra2(x, wavelet, transform=transform, mode=mode)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra2(x, wavelet, transform=transform, mode=mode)
|
||||
assert isinstance(coeffs, list)
|
||||
assert isinstance(coeffs[0], np.ndarray)
|
||||
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
||||
|
||||
y = pywt.imra2(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
||||
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
||||
def test_mra2_warns_on_non_orthogonal(wavelet, transform):
|
||||
dtype = np.float64
|
||||
x = data.camera()[:32, :8].astype(dtype, copy=False)
|
||||
|
||||
assert not pywt.Wavelet(wavelet).orthogonal
|
||||
|
||||
if transform == 'swt2':
|
||||
# bi-orthogonal wavelets raise a warning for SWT case
|
||||
msg = 'norm=True, but the wavelets used are not orthogonal'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
coeffs = pywt.mra2(x, wavelet, transform=transform)
|
||||
else:
|
||||
coeffs = pywt.mra2(x, wavelet, transform=transform)
|
||||
|
||||
y = pywt.imra2(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
||||
@pytest.mark.parametrize('ndim', [2, 3])
|
||||
@pytest.mark.parametrize('axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4)])
|
||||
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
|
||||
def test_mra2_axes(transform, axes, ndim, dtype):
|
||||
# Test transforms over various axes of 2D or 3D data.
|
||||
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
||||
if ndim == 3:
|
||||
x = np.stack((x,) * 8, axis=-1)
|
||||
|
||||
# out of range axis
|
||||
if any([axis < -x.ndim or axis >= x.ndim for axis in axes]):
|
||||
with pytest.raises(np.AxisError):
|
||||
pywt.mra2(x, 'db1', transform=transform, axes=axes)
|
||||
return
|
||||
|
||||
coeffs = pywt.mra2(x, 'db1', transform=transform, axes=axes)
|
||||
y = pywt.imra2(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
####
|
||||
# nd mra tests
|
||||
####
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['sym2', ])
|
||||
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
||||
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
||||
@pytest.mark.parametrize(
|
||||
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
||||
)
|
||||
@pytest.mark.parametrize('ndim', [1, 2, 3])
|
||||
def test_mran_roundtrip(wavelet, transform, mode, dtype, ndim):
|
||||
if ndim == 1:
|
||||
x = data.ecg()[:48].astype(dtype, copy=False)
|
||||
elif ndim == 2:
|
||||
x = data.camera()[:16, :8].astype(dtype, copy=False)
|
||||
elif ndim == 3:
|
||||
x = data.camera()[:16, :8].astype(dtype, copy=False)
|
||||
x = np.stack((x,) * 8, axis=-1)
|
||||
|
||||
if x.dtype.kind == 'c':
|
||||
# fill some data for the imaginary channel
|
||||
x.imag = x[::-1, ...].real
|
||||
|
||||
if transform == 'swtn':
|
||||
# swt mode only supports periodization
|
||||
if mode != 'periodization':
|
||||
with pytest.raises(ValueError):
|
||||
pywt.mran(x, wavelet, transform=transform, mode=mode)
|
||||
return
|
||||
|
||||
coeffs = pywt.mran(x, wavelet, transform=transform, mode=mode)
|
||||
assert isinstance(coeffs, list)
|
||||
assert isinstance(coeffs[0], np.ndarray)
|
||||
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
||||
|
||||
y = pywt.imran(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
||||
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
||||
def test_mran_warns_on_non_orthogonal(wavelet, transform):
|
||||
dtype = np.float64
|
||||
x = data.camera()[:32, :8].astype(dtype, copy=False)
|
||||
|
||||
assert not pywt.Wavelet(wavelet).orthogonal
|
||||
|
||||
if transform == 'swtn':
|
||||
# bi-orthogonal wavelets raise a warning for SWT case
|
||||
msg = 'norm=True, but the wavelets used are not orthogonal'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
coeffs = pywt.mran(x, wavelet, transform=transform)
|
||||
else:
|
||||
coeffs = pywt.mran(x, wavelet, transform=transform)
|
||||
|
||||
y = pywt.imran(coeffs)
|
||||
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4), (-3, -2, -1),
|
||||
(0, 2, 1), (0, 5, 1), (0,), (1,), (2,), (-2,), (-3,), (-4,)])
|
||||
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
||||
def test_mran_axes(axes, transform):
|
||||
# Test with transforms over 1, 2 or 3 axes of 3d data.
|
||||
# Cases with out of range axes are also tested
|
||||
dtype = np.float64
|
||||
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
||||
x3d = np.stack((x,) * 8, axis=-1)
|
||||
|
||||
# out of range axis
|
||||
if any([axis < -x.ndim or axis >= x.ndim for axis in axes]):
|
||||
with pytest.raises(np.AxisError):
|
||||
pywt.mran(x, 'db1', transform='dwtn', axes=axes)
|
||||
return
|
||||
|
||||
coeffs = pywt.mran(x3d, 'db1', transform='dwtn', axes=axes)
|
||||
y = pywt.imran(coeffs)
|
||||
rtol = tol_single if x3d.real.dtype.kind == 'f' else tol_double
|
||||
assert_allclose(x3d, y, rtol=rtol, atol=rtol)
|
||||
443
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multidim.py
vendored
Normal file
443
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multidim.py
vendored
Normal file
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from itertools import combinations
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
|
||||
|
||||
import pywt
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwtn_input():
|
||||
# Array-like must be accepted
|
||||
pywt.dwtn([1, 2, 3, 4], 'haar')
|
||||
# Others must not
|
||||
data = dict()
|
||||
assert_raises(TypeError, pywt.dwtn, data, 'haar')
|
||||
# Must be at least 1D
|
||||
assert_raises(ValueError, pywt.dwtn, 2, 'haar')
|
||||
|
||||
|
||||
def test_3D_reconstruct():
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for mode in pywt.Modes.modes:
|
||||
d = pywt.dwtn(data, wavelet, mode=mode)
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_dwdtn_idwtn_allwavelets():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16, 16)
|
||||
# test 2D case only for all wavelet types
|
||||
wavelist = pywt.wavelist()
|
||||
if 'dmey' in wavelist:
|
||||
wavelist.remove('dmey')
|
||||
for wavelet in wavelist:
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
if isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet):
|
||||
for mode in pywt.Modes.modes:
|
||||
coeffs = pywt.dwtn(r, wavelet, mode=mode)
|
||||
assert_allclose(pywt.idwtn(coeffs, wavelet, mode=mode),
|
||||
r, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def test_stride():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
strided = np.ones((3, 12), dtype=data.dtype)
|
||||
strided[::-1, ::2] = data
|
||||
strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(strided_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_byte_offset():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
padded = np.ones((3, 6), dtype=np.dtype({'data': (data.dtype, 0),
|
||||
'pad': ('byte', data.dtype.itemsize)},
|
||||
align=True))
|
||||
padded[:] = data
|
||||
padded_dwtn = pywt.dwtn(padded['data'], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(padded_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_3D_reconstruct_complex():
|
||||
# All dimensions even length so `take` does not need to be specified
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
data = data + 1j
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
d = pywt.dwtn(data, wavelet)
|
||||
# idwtn creates even-length shapes (2x dwtn size)
|
||||
original_shape = tuple([slice(None, s) for s in data.shape])
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_idwtn_idwt2():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_idwt2_complex():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
data = data + 1j
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_missing():
|
||||
# Test to confirm missing data behave as zeroes
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
coefs = pywt.dwtn(data, wavelet)
|
||||
|
||||
# No point removing zero, or all
|
||||
for num_missing in range(1, len(coefs)):
|
||||
for missing in combinations(coefs.keys(), num_missing):
|
||||
missing_coefs = coefs.copy()
|
||||
for key in missing:
|
||||
del missing_coefs[key]
|
||||
LL = missing_coefs.get('aa', None)
|
||||
HL = missing_coefs.get('da', None)
|
||||
LH = missing_coefs.get('ad', None)
|
||||
HH = missing_coefs.get('dd', None)
|
||||
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet),
|
||||
pywt.idwtn(missing_coefs, 'haar'), atol=1e-15)
|
||||
|
||||
|
||||
def test_idwtn_all_coeffs_None():
|
||||
coefs = dict(aa=None, da=None, ad=None, dd=None)
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
|
||||
|
||||
|
||||
def test_error_on_invalid_keys():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# unexpected key
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
# mismatched key lengths
|
||||
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_error_mismatched_size():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# Pass/fail depends on first element being shorter than remaining ones so
|
||||
# set 3/4 to an incorrect size to maximize chances. Order of dict items
|
||||
# is random so may not trigger on every test run. Dict is constructed
|
||||
# inside idwtn function so no use using an OrderedDict here.
|
||||
LL = LL[:, :-1]
|
||||
LH = LH[:, :-1]
|
||||
HH = HH[:, :-1]
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_dwt2_idwt2_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
|
||||
"dwt2: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
|
||||
|
||||
|
||||
def test_dwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1,))
|
||||
expected_a = list(map(lambda x: pywt.dwt(x, 'haar')[0], data))
|
||||
assert_equal(coefs['a'], expected_a)
|
||||
expected_d = list(map(lambda x: pywt.dwt(x, 'haar')[1], data))
|
||||
assert_equal(coefs['d'], expected_d)
|
||||
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
expected_aa = list(map(lambda x: pywt.dwt(x, 'haar')[0], expected_a))
|
||||
assert_equal(coefs['aa'], expected_aa)
|
||||
expected_ad = list(map(lambda x: pywt.dwt(x, 'haar')[1], expected_a))
|
||||
assert_equal(coefs['ad'], expected_ad)
|
||||
|
||||
|
||||
def test_idwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwt2_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
cD = np.zeros_like(cD)
|
||||
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
cD = None
|
||||
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwtn_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
coefs['dd'] = np.zeros_like(coefs['dd'])
|
||||
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
coefs['dd'] = None
|
||||
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwt2_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
# too many axes
|
||||
assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_idwt2_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4)))
|
||||
# test all combinations of 2 out of 3 axes transformed
|
||||
for axes in combinations((0, 1, 2), 2):
|
||||
coefs = pywt.dwt2(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4, 4)))
|
||||
# test all combinations of 3 out of 4 axes transformed
|
||||
for axes in combinations((0, 1, 2, 3), 3):
|
||||
coefs = pywt.dwtn(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_negative_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs1 = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
coefs2 = pywt.dwtn(data, 'haar', axes=(-1, -1))
|
||||
assert_equal(coefs1, coefs2)
|
||||
|
||||
rec1 = pywt.idwtn(coefs1, 'haar', axes=(1, 1))
|
||||
rec2 = pywt.idwtn(coefs1, 'haar', axes=(-1, -1))
|
||||
assert_equal(rec1, rec2)
|
||||
|
||||
|
||||
def test_dwtn_idwtn_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
coeffs = pywt.dwtn(x, wavelet)
|
||||
for k, v in coeffs.items():
|
||||
assert_(v.dtype == dt_out, "dwtn: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)
|
||||
|
||||
|
||||
def test_idwtn_mixed_complex_dtype():
|
||||
rstate = np.random.RandomState(0)
|
||||
x = rstate.randn(8, 8, 8)
|
||||
x = x + 1j*x
|
||||
coeffs = pywt.dwtn(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
|
||||
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_idwt2_size_mismatch_error():
|
||||
LL = np.zeros((6, 6))
|
||||
LH = HL = HH = np.zeros((5, 5))
|
||||
|
||||
assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')
|
||||
|
||||
|
||||
def test_dwt2_dimension_error():
|
||||
data = np.ones(16)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
# wrong number of input dimensions
|
||||
assert_raises(ValueError, pywt.dwt2, data, wavelet)
|
||||
|
||||
# too many axes
|
||||
data2 = np.ones((8, 8))
|
||||
assert_raises(ValueError, pywt.dwt2, data2, wavelet, axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_per_axis_wavelets_and_modes():
|
||||
# tests separate wavelet and edge mode for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
# mode can be a string or a Modes enum
|
||||
modes = ('symmetric', 'periodization',
|
||||
pywt._extensions._pywt.Modes.reflect)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets[:1], modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes[:1])
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets or modes doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])
|
||||
|
||||
# dwt2/idwt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
|
||||
assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
|
||||
atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
|
||||
[pywt.idwt2, pywt.idwtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_fun, data, wavelet=cwave)
|
||||
|
||||
c = dec_fun(data, 'db1')
|
||||
assert_raises(ValueError, rec_fun, c, wavelet=cwave)
|
||||
1033
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multilevel.py
vendored
Normal file
1033
.CondaPkg/env/Lib/site-packages/pywt/tests/test_multilevel.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
61
.CondaPkg/env/Lib/site-packages/pywt/tests/test_perfect_reconstruction.py
vendored
Normal file
61
.CondaPkg/env/Lib/site-packages/pywt/tests/test_perfect_reconstruction.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Verify DWT perfect reconstruction.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_perfect_reconstruction():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
dtypes = (np.float32, np.float64)
|
||||
|
||||
for wavelet in wavelets:
|
||||
for pmode, mmode in modes:
|
||||
for dt in dtypes:
|
||||
check_reconstruction(pmode, mmode, wavelet, dt)
|
||||
|
||||
|
||||
def check_reconstruction(pmode, mmode, wavelet, dtype):
|
||||
data_size = list(range(2, 40)) + [100, 200, 500, 1000, 2000, 10000,
|
||||
50000, 100000]
|
||||
np.random.seed(12345)
|
||||
# TODO: smoke testing - more failures for different seeds
|
||||
|
||||
if dtype == np.float32:
|
||||
# was 3e-7 has to be lowered as db21, db29, db33, db35, coif14, coif16 were failing
|
||||
epsilon = 6e-7
|
||||
else:
|
||||
epsilon = 5e-11
|
||||
|
||||
for N in data_size:
|
||||
data = np.asarray(np.random.random(N), dtype)
|
||||
|
||||
# compute dwt coefficients
|
||||
pa, pd = pywt.dwt(data, wavelet, pmode)
|
||||
|
||||
# compute reconstruction
|
||||
rec = pywt.idwt(pa, pd, wavelet, pmode)
|
||||
|
||||
if len(data) % 2:
|
||||
rec = rec[:len(data)]
|
||||
|
||||
rms_rec = np.sqrt(np.mean((data-rec)**2))
|
||||
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
|
||||
assert_(rms_rec < epsilon, msg=msg)
|
||||
632
.CondaPkg/env/Lib/site-packages/pywt/tests/test_swt.py
vendored
Normal file
632
.CondaPkg/env/Lib/site-packages/pywt/tests/test_swt.py
vendored
Normal file
@@ -0,0 +1,632 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from itertools import combinations, permutations
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import (assert_allclose, assert_, assert_equal,
|
||||
assert_raises, assert_array_equal, assert_warns)
|
||||
|
||||
import pywt
|
||||
from pywt._extensions._swt import swt_axis
|
||||
|
||||
# Check that float32 and complex64 are preserved. Other real types get
|
||||
# converted to float64.
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# tolerances used in accuracy comparisons
|
||||
tol_single = 1e-6
|
||||
tol_double = 1e-13
|
||||
|
||||
####
|
||||
# 1d multilevel swt tests
|
||||
####
|
||||
|
||||
|
||||
def test_swt_decomposition():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
db1 = pywt.Wavelet('db1')
|
||||
atol = tol_double
|
||||
(cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
|
||||
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
|
||||
2.82842712, 7.07106781, 7.07106781, 6.36396103]
|
||||
assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
|
||||
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
|
||||
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
|
||||
assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
|
||||
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
expected_cA3 = [9.89949494, ] * 8
|
||||
assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
|
||||
expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
|
||||
0.00000000, 3.53553391, 4.24264069, 2.12132034]
|
||||
assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)
|
||||
|
||||
# level=1, start_level=1 decomposition should match level=2
|
||||
res = pywt.swt(cA1, db1, level=1, start_level=1)
|
||||
cA2, cD2 = res[0]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
|
||||
coeffs = pywt.swt(x, db1)
|
||||
assert_(len(coeffs) == 3)
|
||||
assert_(pywt.swt_max_level(len(x)), 3)
|
||||
|
||||
|
||||
def test_swt_max_level():
|
||||
# odd sized signal will warn about no levels of decomposition possible
|
||||
assert_warns(UserWarning, pywt.swt_max_level, 11)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', UserWarning)
|
||||
assert_equal(pywt.swt_max_level(11), 0)
|
||||
|
||||
# no warnings when >= 1 level of decomposition possible
|
||||
assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1
|
||||
assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2
|
||||
assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4
|
||||
assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4
|
||||
|
||||
|
||||
def test_swt_axis():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
|
||||
db1 = pywt.Wavelet('db1')
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
|
||||
|
||||
# test cases use 2D arrays based on tiling x along an axis and then
|
||||
# calling swt along the other axis.
|
||||
for order in ['C', 'F']:
|
||||
# test SWT of 2D data along default axis (-1)
|
||||
x_2d = np.asarray(x).reshape((1, -1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=0)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each row should match the 1D result
|
||||
for row in cA1_2d:
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d:
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d:
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d:
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# test SWT of 2D data along other axis (0)
|
||||
x_2d = np.asarray(x).reshape((-1, 1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=1)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2,
|
||||
axis=0)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each column should match the 1D result
|
||||
for row in cA1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# axis too large
|
||||
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
|
||||
|
||||
|
||||
def test_swt_iswt_integration():
|
||||
# This function performs a round-trip swt/iswt transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt or iswt as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet seems to be a bit special - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length)
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
# swt
|
||||
x = np.ones(8, dtype=dt_in)
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2)
|
||||
assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out,
|
||||
"swt: " + errmsg)
|
||||
|
||||
# swt2
|
||||
x = np.ones((8, 8), dtype=dt_in)
|
||||
cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0]
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out,
|
||||
"swt2: " + errmsg)
|
||||
|
||||
|
||||
def test_swt_roundtrip_dtypes():
|
||||
# verify perfect reconstruction for all dtypes
|
||||
rstate = np.random.RandomState(5)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
# swt, iswt
|
||||
x = rstate.standard_normal((8, )).astype(dt_in)
|
||||
c = pywt.swt(x, wavelet, level=2)
|
||||
xr = pywt.iswt(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
# swt2, iswt2
|
||||
x = rstate.standard_normal((8, 8)).astype(dt_in)
|
||||
c = pywt.swt2(x, wavelet, level=2)
|
||||
xr = pywt.iswt2(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_default_level_by_axis():
|
||||
# make sure default number of levels matches the max level along the axis
|
||||
wav = 'db2'
|
||||
x = np.ones((2**3, 2**4, 2**5))
|
||||
for axis in (0, 1, 2):
|
||||
sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis)
|
||||
assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis]))
|
||||
|
||||
|
||||
def test_swt2_ndim_error():
|
||||
x = np.ones(8)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swt2_iswt2_integration(wavelets=None):
|
||||
# This function performs a round-trip swt2/iswt2 transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt2 or iswt2 as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt2(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
def test_swt2_iswt2_quick():
|
||||
test_swt2_iswt2_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_swt2_iswt2_non_square(wavelets=None):
|
||||
for nrows in [8, 16, 48]:
|
||||
X = np.arange(nrows*32).reshape(nrows, 32)
|
||||
current_wavelet = 'db1'
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
coeffs = pywt.swt2(X, current_wavelet, level=2)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet)
|
||||
assert_allclose(Y, X, rtol=tol_single, atol=tol_single)
|
||||
|
||||
|
||||
def test_swt2_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
(cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0]
|
||||
# opposite order
|
||||
(cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1,
|
||||
axes=(1, 0))[0]
|
||||
assert_allclose(cA1, cA2, atol=atol)
|
||||
assert_allclose(cH1, cV2, atol=atol)
|
||||
assert_allclose(cV1, cH2, atol=atol)
|
||||
assert_allclose(cD1, cD2, atol=atol)
|
||||
|
||||
# reverify iswt2 restores the original data
|
||||
r1 = pywt.iswt2([cA1, (cH1, cV1, cD1)], current_wavelet)
|
||||
assert_allclose(X, r1, atol=atol)
|
||||
r2 = pywt.iswt2([cA2, (cH2, cV2, cD2)], current_wavelet, axes=(1, 0))
|
||||
assert_allclose(X, r2, atol=atol)
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
|
||||
axes=(0, 0))
|
||||
# too few axes
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
|
||||
|
||||
|
||||
def test_swtn_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
|
||||
# opposite order
|
||||
coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
|
||||
assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
|
||||
assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
|
||||
assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
|
||||
assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)
|
||||
|
||||
# 0-level transform
|
||||
empty = pywt.swtn(X, current_wavelet, level=0)
|
||||
assert_equal(empty, [])
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
|
||||
|
||||
# data.ndim = 0
|
||||
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
|
||||
|
||||
# start_level too large
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
|
||||
level=1, start_level=2)
|
||||
|
||||
# level < 1 in swt_axis call
|
||||
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
|
||||
start_level=0)
|
||||
# odd-sized data not allowed
|
||||
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
|
||||
start_level=0, axis=0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swtn_iswtn_integration(wavelets=None):
|
||||
# This function performs a round-trip swtn/iswtn transform for various
|
||||
# possible combinations of:
|
||||
# 1.) 1 out of 2 axes of a 2D array
|
||||
# 2.) 2 out of 3 axes of a 3D array
|
||||
#
|
||||
# To keep test time down, only wavelets of length <= 8 are run.
|
||||
#
|
||||
# This test does not validate swtn or iswtn individually, but only
|
||||
# confirms that iswtn yields an (almost) perfect reconstruction of swtn.
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for ndim_transform in range(1, 3):
|
||||
ndim = ndim_transform + 1
|
||||
for axes in combinations(range(ndim), ndim_transform):
|
||||
for current_wavelet_str in wavelets:
|
||||
wav = pywt.Wavelet(current_wavelet_str)
|
||||
if wav.dec_len > 8:
|
||||
continue # avoid excessive test duration
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
wav.dec_len,
|
||||
wav.rec_len))))
|
||||
N = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(N**ndim).reshape((N, )*ndim)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not wav.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swtn(X, wav, max_level, axes=axes,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
coeffs_copy = deepcopy(coeffs)
|
||||
Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
# verify the inverse transform didn't modify any coeffs
|
||||
for c, c2 in zip(coeffs, coeffs_copy):
|
||||
for k, v in c.items():
|
||||
assert_array_equal(c2[k], v)
|
||||
|
||||
|
||||
def test_swtn_iswtn_quick():
|
||||
test_swtn_iswtn_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_iswtn_errors():
|
||||
x = np.arange(8**3).reshape(8, 8, 8)
|
||||
max_level = 2
|
||||
axes = (0, 1)
|
||||
w = pywt.Wavelet('db1')
|
||||
coeffs = pywt.swtn(x, w, max_level, axes=axes)
|
||||
|
||||
# more axes than dimensions transformed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
|
||||
# mismatched coefficient size
|
||||
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
|
||||
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
|
||||
|
||||
|
||||
def test_swtn_iswtn_unique_shape_per_axis():
|
||||
# test case for gh-460
|
||||
_shape = (1, 48, 32) # unique shape per axis
|
||||
wav = 'sym2'
|
||||
max_level = 3
|
||||
rstate = np.random.RandomState(0)
|
||||
for shape in permutations(_shape):
|
||||
# transform only along the non-singleton axes
|
||||
axes = [ax for ax, s in enumerate(shape) if s != 1]
|
||||
x = rstate.standard_normal(shape)
|
||||
c = pywt.swtn(x, wav, max_level, axes=axes)
|
||||
r = pywt.iswtn(c, wav, axes=axes)
|
||||
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
|
||||
|
||||
|
||||
def test_per_axis_wavelets():
|
||||
# tests separate wavelet for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
level = 3
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
coefs = pywt.swtn(data, wavelets, level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)
|
||||
|
||||
# 1-tuple also okay
|
||||
coefs = pywt.swtn(data, wavelets[:1], level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
|
||||
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
# swt2/iswt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.swt2(data2, wavelets[:2], level)
|
||||
assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_func, data, wavelet=cwave,
|
||||
level=3)
|
||||
|
||||
c = dec_func(data, 'db1', level=3)
|
||||
assert_raises(ValueError, rec_func, c, wavelet=cwave)
|
||||
|
||||
|
||||
def test_iswt_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
x_real = np.arange(16).astype(np.float64)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
coeffs[0][1].astype(dtype2)]
|
||||
y = pywt.iswt(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswt2_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt2(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
|
||||
y = pywt.iswt2(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswtn_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swtn(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
a = coeffs[0].pop('a' * x.ndim)
|
||||
a = a.astype(dtype1)
|
||||
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
|
||||
coeffs[0]['a' * x.ndim] = a
|
||||
y = pywt.iswtn(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_swt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.swt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
|
||||
|
||||
|
||||
def test_swt_variance_and_energy_preservation():
|
||||
"""Verify that the 1D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(256)
|
||||
coeffs = pywt.swt(x, wav, trim_approx=True, norm=True)
|
||||
variances = [np.var(c) for c in coeffs]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeffs)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
|
||||
|
||||
|
||||
def test_swt2_variance_and_energy_preservation():
|
||||
"""Verify that the 2D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for v in d:
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swtn_variance_and_energy_preservation():
|
||||
"""Verify that the nD SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for k, v in d.items():
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swt_ravel_and_unravel():
|
||||
# When trim_approx=True, all swt functions can user pywt.ravel_coeffs
|
||||
for ndim, _swt, _iswt, ravel_type in [
|
||||
(1, pywt.swt, pywt.iswt, 'swt'),
|
||||
(2, pywt.swt2, pywt.iswt2, 'swt2'),
|
||||
(3, pywt.swtn, pywt.iswtn, 'swtn')]:
|
||||
x = np.ones((16, ) * ndim)
|
||||
c = _swt(x, 'sym2', level=3, trim_approx=True)
|
||||
arr, slices, shapes = pywt.ravel_coeffs(c)
|
||||
c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type)
|
||||
r = _iswt(c, 'sym2')
|
||||
assert_allclose(x, r)
|
||||
169
.CondaPkg/env/Lib/site-packages/pywt/tests/test_thresholding.py
vendored
Normal file
169
.CondaPkg/env/Lib/site-packages/pywt/tests/test_thresholding.py
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
|
||||
real_dtypes = [np.float32, np.float64]
|
||||
|
||||
|
||||
def _sign(x):
|
||||
# Matlab-like sign function (numpy uses a different convention).
|
||||
return x / np.abs(x)
|
||||
|
||||
|
||||
def _soft(x, thresh):
|
||||
"""soft thresholding supporting complex values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This version is not robust to zeros in x.
|
||||
"""
|
||||
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
|
||||
|
||||
|
||||
def test_threshold():
|
||||
data = np.linspace(1, 4, 7)
|
||||
|
||||
# soft
|
||||
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'soft'),
|
||||
np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'soft'),
|
||||
-np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
|
||||
[[0, 1]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
|
||||
# soft thresholding complex values
|
||||
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
|
||||
[[0j, 1j]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
complex_data = [[1+2j, 2+2j]]*2
|
||||
for thresh in [1, 2]:
|
||||
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
|
||||
_soft(complex_data, thresh), rtol=1e-12)
|
||||
|
||||
# test soft thresholding with non-default substitute argument
|
||||
s = 5
|
||||
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
|
||||
[[s, 0.5]] * 2, rtol=1e-12)
|
||||
|
||||
# soft: no divide by zero warnings when input contains zeros
|
||||
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
|
||||
np.zeros(16), rtol=1e-12)
|
||||
|
||||
# hard
|
||||
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'hard'),
|
||||
np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'hard'),
|
||||
-np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
|
||||
[[0, 2+2j]] * 2, rtol=1e-12)
|
||||
|
||||
# greater
|
||||
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'greater'),
|
||||
np.array(greater_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
# greater doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
|
||||
|
||||
# less
|
||||
assert_allclose(pywt.threshold(data, 2, 'less'),
|
||||
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
|
||||
[[1, 0]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
|
||||
[[1, s]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
|
||||
# less doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
|
||||
|
||||
# invalid
|
||||
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
|
||||
|
||||
|
||||
def test_nonnegative_garotte():
|
||||
thresh = 0.3
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_garotte = pywt.threshold(data, thresh, 'garotte')
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_garotte.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_garotte[lt] == 0))
|
||||
|
||||
# values > than the threshold are intermediate between soft and hard
|
||||
gt = np.where(np.abs(data) > thresh)
|
||||
gt_abs_garotte = np.abs(d_garotte[gt])
|
||||
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
|
||||
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
|
||||
|
||||
|
||||
def test_threshold_firm():
|
||||
thresh = 0.2
|
||||
thresh2 = 3 * thresh
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
if data.real.dtype == np.float32:
|
||||
rtol = atol = 1e-6
|
||||
else:
|
||||
rtol = atol = 1e-14
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_firm = pywt.threshold_firm(data, thresh, thresh2)
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_firm.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_firm[lt] == 0))
|
||||
|
||||
# values > than the threshold are equal to hard-thresholding
|
||||
gt = np.where(np.abs(data) >= thresh2)
|
||||
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
# other values are intermediate between soft and hard thresholding
|
||||
mt = np.where(np.logical_and(np.abs(data) > thresh,
|
||||
np.abs(data) < thresh2))
|
||||
mt_abs_firm = np.abs(d_firm[mt])
|
||||
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
|
||||
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
|
||||
277
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wavelet.py
vendored
Normal file
277
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wavelet.py
vendored
Normal file
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_properties():
|
||||
w = pywt.Wavelet('db3')
|
||||
|
||||
# Name
|
||||
assert_(w.name == 'db3')
|
||||
assert_(w.short_family_name == 'db')
|
||||
assert_(w.family_name, 'Daubechies')
|
||||
|
||||
# String representation
|
||||
fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
|
||||
'Biorthogonal', 'Symmetry')
|
||||
for field in fields:
|
||||
assert_(field in str(w))
|
||||
|
||||
# Filter coefficients
|
||||
dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
|
||||
0.45987750211933, 0.80689150931334, 0.33267055295096]
|
||||
dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
|
||||
-0.13501102001039, 0.08544127388224, 0.03522629188210]
|
||||
rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
|
||||
-0.13501102001039, -0.08544127388224, 0.03522629188210]
|
||||
rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
|
||||
-0.45987750211933, 0.80689150931334, -0.33267055295096]
|
||||
assert_allclose(w.dec_lo, dec_lo)
|
||||
assert_allclose(w.dec_hi, dec_hi)
|
||||
assert_allclose(w.rec_lo, rec_lo)
|
||||
assert_allclose(w.rec_hi, rec_hi)
|
||||
|
||||
assert_(len(w.filter_bank) == 4)
|
||||
|
||||
# Orthogonality
|
||||
assert_(w.orthogonal)
|
||||
assert_(w.biorthogonal)
|
||||
|
||||
# Symmetry
|
||||
assert_(w.symmetry)
|
||||
|
||||
# Vanishing moments
|
||||
assert_(w.vanishing_moments_phi == 0)
|
||||
assert_(w.vanishing_moments_psi == 3)
|
||||
|
||||
|
||||
def test_wavelet_coefficients():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
for wavelet in wavelets:
|
||||
if (pywt.Wavelet(wavelet).orthogonal):
|
||||
check_coefficients_orthogonal(wavelet)
|
||||
elif(pywt.Wavelet(wavelet).biorthogonal):
|
||||
check_coefficients_biorthogonal(wavelet)
|
||||
else:
|
||||
check_coefficients(wavelet)
|
||||
|
||||
|
||||
def check_coefficients_orthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi, psi, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
|
||||
res = np.sum(phi) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Wavelet function is orthogonal to the scaling function at the same scale
|
||||
res = np.sum(phi*psi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# The lowpass and highpass filter coefficients are orthogonal
|
||||
res = np.sum(np.array(w.dec_lo)*np.array(w.dec_hi))
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients_biorthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
res = np.sum(phi_d) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(phi_r) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients(wavelet):
|
||||
epsilon = 5e-11
|
||||
level = 10
|
||||
w = pywt.Wavelet(wavelet)
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
class _CustomHaarFilterBank(object):
|
||||
@property
|
||||
def filter_bank(self):
|
||||
val = np.sqrt(2) / 2
|
||||
return ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
|
||||
|
||||
def test_custom_wavelet():
|
||||
haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=_CustomHaarFilterBank())
|
||||
haar_custom1.orthogonal = True
|
||||
haar_custom1.biorthogonal = True
|
||||
|
||||
val = np.sqrt(2) / 2
|
||||
filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=filter_bank)
|
||||
|
||||
# check expected default wavelet properties
|
||||
assert_(~haar_custom2.orthogonal)
|
||||
assert_(~haar_custom2.biorthogonal)
|
||||
assert_(haar_custom2.symmetry == 'unknown')
|
||||
assert_(haar_custom2.family_name == '')
|
||||
assert_(haar_custom2.short_family_name == '')
|
||||
assert_(haar_custom2.vanishing_moments_phi == 0)
|
||||
assert_(haar_custom2.vanishing_moments_psi == 0)
|
||||
|
||||
# Some properties can be set by the user
|
||||
haar_custom2.orthogonal = True
|
||||
haar_custom2.biorthogonal = True
|
||||
|
||||
|
||||
def test_wavefun_sym3():
|
||||
w = pywt.Wavelet('sym3')
|
||||
# sym3 is an orthogonal wavelet, so 3 outputs from wavefun
|
||||
phi, psi, x = w.wavefun(level=3)
|
||||
assert_(phi.size == 41)
|
||||
assert_(psi.size == 41)
|
||||
assert_(x.size == 41)
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, num=x.size))
|
||||
phi_expect = np.array([0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
|
||||
3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
|
||||
8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
|
||||
1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
|
||||
4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
|
||||
-7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
|
||||
-2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
|
||||
-2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
|
||||
8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
|
||||
1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
|
||||
-1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
|
||||
2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
|
||||
1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
psi_expect = np.array([0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
|
||||
4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
|
||||
9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
|
||||
6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
|
||||
-2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
|
||||
-5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
|
||||
2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
|
||||
1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
|
||||
-9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
|
||||
-1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
|
||||
1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
|
||||
-2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
|
||||
-1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
assert_allclose(phi, phi_expect)
|
||||
assert_allclose(psi, psi_expect)
|
||||
|
||||
|
||||
def test_wavefun_bior13():
|
||||
w = pywt.Wavelet('bior1.3')
|
||||
# bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
|
||||
for arr in [phi_d, psi_d, phi_r, psi_r]:
|
||||
assert_(arr.size == 40)
|
||||
|
||||
phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
|
||||
0.01367188, 0.00390625, -0.03515625, -0.12890625,
|
||||
-0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
|
||||
0.15234375, 0.37890625, 0.78515625, 0.99609375,
|
||||
1.08203125, 1.13671875, 1.13671875, 1.08203125,
|
||||
0.99609375, 0.78515625, 0.37890625, 0.15234375,
|
||||
0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
|
||||
-0.12890625, -0.03515625, 0.00390625, 0.01367188,
|
||||
0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
|
||||
phi_r_expect = np.zeros(x.size, dtype=np.float64)
|
||||
phi_r_expect[15:23] = 1
|
||||
|
||||
psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0.015625, -0.015625, -0.140625, -0.109375,
|
||||
-0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
|
||||
-0.625, -1.125, -1.21875, -1.03125, -0.28125,
|
||||
0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
psi_r_expect = np.zeros(x.size, dtype=np.float64)
|
||||
psi_r_expect[7:15] = -0.125
|
||||
psi_r_expect[15:19] = 1
|
||||
psi_r_expect[19:23] = -1
|
||||
psi_r_expect[23:31] = 0.125
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
|
||||
assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
|
||||
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_pickle(tmpdir):
|
||||
wavelet = pywt.Wavelet('sym4')
|
||||
filename = os.path.join(tmpdir, 'wav.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(wavelet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
wavelet2 = pickle.load(f)
|
||||
assert isinstance(wavelet2, pywt.Wavelet)
|
||||
assert wavelet2.name == wavelet.name
|
||||
244
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp.py
vendored
Normal file
244
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp.py
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_packet_structure():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
|
||||
|
||||
def test_traversing_wp_tree():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
# First level
|
||||
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
|
||||
7.778174593052, 10.606601717798]),
|
||||
rtol=1e-12)
|
||||
|
||||
# Second level
|
||||
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
|
||||
|
||||
# Third level
|
||||
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
|
||||
|
||||
|
||||
def test_acess_path():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp['a'].path == 'a')
|
||||
assert_(wp['aa'].path == 'aa')
|
||||
assert_(wp['aaa'].path == 'aaa')
|
||||
|
||||
# Maximum level reached:
|
||||
assert_raises(IndexError, lambda: wp['aaaa'].path)
|
||||
|
||||
# Wrong path
|
||||
assert_raises(ValueError, lambda: wp['ac'].path)
|
||||
|
||||
|
||||
def test_access_node_attributes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
|
||||
assert_(wp['ad'].path == 'ad')
|
||||
assert_(wp['ad'].node_name == 'd')
|
||||
assert_(wp['ad'].parent.path == 'a')
|
||||
assert_(wp['ad'].level == 2)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
assert_(wp['ad'].mode == 'symmetric')
|
||||
|
||||
# tuple-based access is also supported
|
||||
node = wp[('a', 'd')]
|
||||
# can access a node's path as either a single string or in tuple form
|
||||
assert_(node.path == 'ad')
|
||||
assert_(node.path_tuple == ('a', 'd'))
|
||||
|
||||
|
||||
def test_collecting_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# All nodes in natural order
|
||||
assert_([node.path for node in wp.get_level(3, 'natural')] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
# and in frequency order.
|
||||
assert_([node.path for node in wp.get_level(3, 'freq')] ==
|
||||
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
|
||||
|
||||
assert_raises(ValueError, wp.get_level, 3, 'invalid_order')
|
||||
|
||||
|
||||
def test_reconstructing_data():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# Create another Wavelet Packet and feed it with some data.
|
||||
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['aa'] = wp['aa'].data
|
||||
new_wp['ad'] = [-2., -2.]
|
||||
|
||||
# For convenience, :attr:`Node.data` gets automatically extracted
|
||||
# from the :class:`Node` object:
|
||||
new_wp['d'] = wp['d']
|
||||
|
||||
# Reconstruct data from aa, ad, and d packets.
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
# The node's :attr:`~Node.data` will not be updated
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
# When `update` is True:
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
|
||||
['aa', 'ad', 'd'])
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
|
||||
def test_removing_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
wp.get_level(2)
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
node = wp['ad']
|
||||
del(wp['ad'])
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(3):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
wp.reconstruct()
|
||||
# The reconstruction is:
|
||||
assert_allclose(wp.reconstruct(),
|
||||
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
|
||||
|
||||
# Restore the data
|
||||
wp['ad'].data = node.data
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
rstate = np.random.RandomState(0)
|
||||
N = 32
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = rstate.randn(N).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assigning to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# first element of the tuple is the input dtype
|
||||
# second element of the tuple is the transform dtype
|
||||
dtype_pairs = [(np.uint8, np.float64),
|
||||
(np.intp, np.float64), ]
|
||||
if hasattr(np, "complex256"):
|
||||
dtype_pairs += [(np.complex256, np.complex128), ]
|
||||
if hasattr(np, "half"):
|
||||
dtype_pairs += [(np.half, np.float32), ]
|
||||
for (dtype, transform_dtype) in dtype_pairs:
|
||||
x = np.arange(N, dtype=dtype)
|
||||
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# no unnecessary copy made of top-level data
|
||||
assert_(wp.data is x)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstructed data will have modified dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, transform_dtype)
|
||||
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_db3_roundtrip():
|
||||
original = np.arange(512)
|
||||
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_axis():
|
||||
rstate = np.random.RandomState(0)
|
||||
shape = (32, 16)
|
||||
x = rstate.standard_normal(shape)
|
||||
for axis in [0, 1, -1]:
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric',
|
||||
axis=axis)
|
||||
|
||||
# partial decomposition
|
||||
nodes = wp.get_level(2)
|
||||
# size along the transformed axis has changed
|
||||
for ax2 in range(x.ndim):
|
||||
if ax2 == (axis % x.ndim):
|
||||
nodes[0].data.shape[ax2] < x.shape[ax2]
|
||||
else:
|
||||
nodes[0].data.shape[ax2] == x.shape[ax2]
|
||||
|
||||
# recontsruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-12, rtol=1e-12)
|
||||
|
||||
# ValueError if axis is out of range
|
||||
assert_raises(ValueError, pywt.WaveletPacket, data=x, wavelet='db1',
|
||||
axis=x.ndim)
|
||||
|
||||
|
||||
def test_wavelet_packet_pickle(tmpdir):
|
||||
packet = pywt.WaveletPacket(np.arange(16), 'sym4')
|
||||
filename = os.path.join(tmpdir, 'wp.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(packet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
packet2 = pickle.load(f)
|
||||
assert isinstance(packet2, pywt.WaveletPacket)
|
||||
245
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp2d.py
vendored
Normal file
245
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wp2d.py
vendored
Normal file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_traversing_tree_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(np.all(wp.data == x))
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
|
||||
assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)
|
||||
|
||||
assert_(wp['a']['a'].data is wp['aa'].data)
|
||||
assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)
|
||||
|
||||
assert_raises(IndexError, lambda: wp['aaaa'])
|
||||
assert_raises(ValueError, lambda: wp['f'])
|
||||
|
||||
|
||||
def test_accessing_node_attributes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
|
||||
assert_(wp['av'].path == 'av')
|
||||
assert_(wp['av'].node_name == 'v')
|
||||
assert_(wp['av'].parent.path == 'a')
|
||||
|
||||
assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
# can also index via a tuple instead of concatenated strings
|
||||
assert_(wp['av'].level == 2)
|
||||
assert_(wp['av'].maxlevel == 3)
|
||||
assert_(wp['av'].mode == 'symmetric')
|
||||
|
||||
# tuple-based access is also supported
|
||||
node = wp[('a', 'v')]
|
||||
# can access a node's path as either a single string or in tuple form
|
||||
assert_(node.path == 'av')
|
||||
assert_(node.path_tuple == ('a', 'v'))
|
||||
|
||||
|
||||
def test_collecting_nodes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(len(wp.get_level(0)) == 1)
|
||||
assert_(wp.get_level(0)[0].path == '')
|
||||
|
||||
# First level
|
||||
assert_(len(wp.get_level(1)) == 4)
|
||||
assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])
|
||||
|
||||
# Second level
|
||||
assert_(len(wp.get_level(2)) == 16)
|
||||
paths = [node.path for node in wp.get_level(2)]
|
||||
expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
|
||||
'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
# Third level.
|
||||
assert_(len(wp.get_level(3)) == 64)
|
||||
paths = [node.path for node in wp.get_level(3)]
|
||||
expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
|
||||
'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
|
||||
'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
|
||||
'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
|
||||
'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
|
||||
'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
|
||||
'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
|
||||
'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']
|
||||
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
# test 2D frequency ordering at the first level
|
||||
fnodes = wp.get_level(1, order='freq')
|
||||
assert_(fnodes[0][0].path == 'a')
|
||||
assert_(fnodes[0][1].path == 'v')
|
||||
assert_(fnodes[1][0].path == 'h')
|
||||
assert_(fnodes[1][1].path == 'd')
|
||||
|
||||
# test 2D frequency ordering at the second level
|
||||
fnodes = wp.get_level(2, order='freq')
|
||||
assert_([n.path for n in fnodes[0]] == ['aa', 'av', 'vv', 'va'])
|
||||
assert_([n.path for n in fnodes[1]] == ['ah', 'ad', 'vd', 'vh'])
|
||||
assert_([n.path for n in fnodes[2]] == ['hh', 'hd', 'dd', 'dh'])
|
||||
assert_([n.path for n in fnodes[3]] == ['ha', 'hv', 'dv', 'da'])
|
||||
|
||||
# invalid node collection order
|
||||
assert_raises(ValueError, wp.get_level, 2, 'invalid_order')
|
||||
|
||||
|
||||
def test_data_reconstruction_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['vh'] = wp['vh'].data
|
||||
new_wp['vv'] = wp['vh'].data
|
||||
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['h'] = wp['h'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)
|
||||
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
|
||||
def test_data_reconstruction_delete_nodes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['vh'] = wp['vh'].data
|
||||
new_wp['vv'] = wp['vh'].data
|
||||
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['h'] = wp['h'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
del(new_wp['va'])
|
||||
# TypeError on accessing deleted node
|
||||
assert_raises(TypeError, lambda: new_wp['va'])
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, x, rtol=1e-12)
|
||||
|
||||
# TODO: decompose=True
|
||||
|
||||
|
||||
def test_lazy_evaluation_2D():
|
||||
# Note: internal implementation detail not to be relied on. Testing for
|
||||
# now for backwards compatibility, but this test may be broken in needed.
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.a is None)
|
||||
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
|
||||
assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
shape = (16, 16)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(*shape).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assigning to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_2d_roundtrip():
|
||||
# test case corresponding to PyWavelets issue 447
|
||||
original = pywt.data.camera()
|
||||
wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_axes():
|
||||
rstate = np.random.RandomState(0)
|
||||
shape = (32, 16)
|
||||
x = rstate.standard_normal(shape)
|
||||
for axes in [(0, 1), (1, 0), (-2, 1)]:
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric',
|
||||
axes=axes)
|
||||
|
||||
# partial decomposition
|
||||
nodes = wp.get_level(2)
|
||||
# size along the transformed axes has changed
|
||||
for ax2 in range(x.ndim):
|
||||
if ax2 in tuple(np.asarray(axes) % x.ndim):
|
||||
nodes[0].data.shape[ax2] < x.shape[ax2]
|
||||
else:
|
||||
nodes[0].data.shape[ax2] == x.shape[ax2]
|
||||
|
||||
# recontsruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-12, rtol=1e-12)
|
||||
|
||||
# must have two non-duplicate axes
|
||||
assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
|
||||
axes=(0, 0))
|
||||
assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
|
||||
axes=(0, ))
|
||||
assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
|
||||
axes=(0, 1, 2))
|
||||
|
||||
|
||||
def test_wavelet_packet2d_pickle(tmpdir):
|
||||
packet = pywt.WaveletPacket2D(np.arange(256).reshape(16, 16), 'sym4')
|
||||
filename = os.path.join(tmpdir, 'wp2d.pickle')
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(packet, f)
|
||||
with open(filename, 'rb') as f:
|
||||
packet2 = pickle.load(f)
|
||||
assert isinstance(packet2, pywt.WaveletPacket2D)
|
||||
171
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wpnd.py
vendored
Normal file
171
.CondaPkg/env/Lib/site-packages/pywt/tests/test_wpnd.py
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
import operator
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_traversing_tree_nd():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(np.all(wp.data == x))
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
assert_allclose(wp['aa'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['da'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['ad'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['dd'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
|
||||
assert_allclose(wp['aa'*2].data, np.array([[10., 26.]] * 2), rtol=1e-12)
|
||||
# __getitem__ using a tuple access instead
|
||||
assert_allclose(wp[('aa', 'aa')].data, np.array([[10., 26.]] * 2),
|
||||
rtol=1e-12)
|
||||
|
||||
assert_(wp['aa']['aa'].data is wp['aa'*2].data)
|
||||
assert_allclose(wp['aa'*3].data, np.array([[36.]]), rtol=1e-12)
|
||||
|
||||
assert_raises(IndexError, lambda: wp['aa'*(wp.maxlevel+1)])
|
||||
assert_raises(ValueError, lambda: wp['f'])
|
||||
|
||||
# getitem input must be a string or tuple of strings
|
||||
assert_raises(TypeError, wp.__getitem__, (5, 3))
|
||||
assert_raises(TypeError, wp.__getitem__, 5)
|
||||
|
||||
|
||||
def test_accessing_node_attributes_nd():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['aa'+'ad'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
|
||||
assert_(wp['aa'+'ad'].path == 'aa'+'ad')
|
||||
assert_(wp['aa'+'ad'].node_name == 'ad')
|
||||
assert_(wp['aa'+'ad'].parent.path == 'aa')
|
||||
|
||||
assert_allclose(wp['aa'+'ad'].parent.data,
|
||||
np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
|
||||
# can also index via a tuple instead of concatenated strings
|
||||
assert_(wp[('aa', 'ad')].level == 2)
|
||||
assert_(wp[('aa', 'ad')].maxlevel == 3)
|
||||
assert_(wp[('aa', 'ad')].mode == 'symmetric')
|
||||
|
||||
# can access a node's path as either a single string or in tuple form
|
||||
node = wp[('ad', 'dd')]
|
||||
assert_(node.path == 'addd')
|
||||
assert_(node.path_tuple == ('ad', 'dd'))
|
||||
|
||||
|
||||
def test_collecting_nodes_nd():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(len(wp.get_level(0)) == 1)
|
||||
assert_(wp.get_level(0)[0].path == '')
|
||||
|
||||
# First level
|
||||
assert_(len(wp.get_level(1)) == 4)
|
||||
assert_(
|
||||
[node.path for node in wp.get_level(1)] == ['aa', 'ad', 'da', 'dd'])
|
||||
|
||||
# Second and third levels
|
||||
for lev in [2, 3]:
|
||||
assert_(len(wp.get_level(lev)) == (2**x.ndim)**lev)
|
||||
paths = [node.path for node in wp.get_level(lev)]
|
||||
expected_paths = [
|
||||
reduce(operator.add, p) for
|
||||
p in sorted(product(['aa', 'ad', 'da', 'dd'], repeat=lev))]
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
|
||||
def test_data_reconstruction_delete_nodes_nd():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# The user must supply either data or axes
|
||||
assert_raises(ValueError, pywt.WaveletPacketND, data=None, wavelet='db1',
|
||||
axes=None)
|
||||
|
||||
new_wp = pywt.WaveletPacketND(data=None, wavelet='db1', mode='symmetric',
|
||||
axes=range(x.ndim))
|
||||
|
||||
new_wp['ad'+'da'] = wp['ad'+'da'].data
|
||||
new_wp['ad'*2] = wp['ad'+'da'].data
|
||||
new_wp['ad'+'dd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['aa'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['dd'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['da'] = wp['da'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
|
||||
new_wp['ad'+'aa'] = wp['ad'+'aa'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
del(new_wp['ad'+'aa'])
|
||||
# TypeError on accessing deleted node
|
||||
assert_raises(TypeError, lambda: new_wp['ad'+'aa'])
|
||||
|
||||
new_wp['ad'+'aa'] = wp['ad'+'aa'].data
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, x, rtol=1e-12)
|
||||
|
||||
# TODO: decompose=True
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
shape = (16, 8, 8)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(*shape).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
def test_wavelet_packet_axes():
|
||||
rstate = np.random.RandomState(0)
|
||||
shape = (32, 16, 8)
|
||||
x = rstate.standard_normal(shape)
|
||||
for axes in [(0, 1), 1, (-3, -2, -1), (0, 2), (1, )]:
|
||||
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric',
|
||||
axes=axes)
|
||||
|
||||
# partial decomposition
|
||||
nodes = wp.get_level(1)
|
||||
# size along the transformed axes has changed
|
||||
for ax2 in range(x.ndim):
|
||||
if ax2 in tuple(np.atleast_1d(axes) % x.ndim):
|
||||
nodes[0].data.shape[ax2] < x.shape[ax2]
|
||||
else:
|
||||
nodes[0].data.shape[ax2] == x.shape[ax2]
|
||||
|
||||
# recontsruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-12, rtol=1e-12)
|
||||
|
||||
# must have non-duplicate axes
|
||||
assert_raises(ValueError, pywt.WaveletPacketND, data=x, wavelet='db1',
|
||||
axes=(0, 0))
|
||||
Reference in New Issue
Block a user