Skip to content

Commit b3ddb2c

Browse files
endolith authored and larsoner committed
TST: Improve coverage of convolve, etc. (scipy#8413)
* DOC: Harmonize 'valid' docs / PEP8 changes
* BUG: Don't pass invalid modes to numpy.convolve. Change _np_conv_ok so that it only passes to numpy if mode is `full` or `valid` (or `same` when the first input is at least as large as the second). If mode is invalid, it will be rejected elsewhere.
* TST: Validate method arg of convolve/correlate (scipy#7211 review): "Something I noticed that could be a candidate for another PR is that convolve doesn't appear to check the validity of the mode and method arguments." Invalid modes were already covered; invalid methods were not.
* TST: correlate's numpy fastpath
* TST: Mismatched dimensions in conv/correlate funcs (convolve/correlate/fftconvolve/convolve2d)
* TST: Invalid fftconvolve mode
* TST: Basic deconvolve test
* TST: Make class for tests that don't care about dt. Also PEP8 changes
* BUG: Catch all mismatched dimensions in convolve
1 parent 386542c commit b3ddb2c

File tree

2 files changed

+147
-54
lines changed

2 files changed

+147
-54
lines changed

scipy/signal/signaltools.py

Lines changed: 65 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,8 @@ def correlate(in1, in2, mode='full', method='auto'):
165165
166166
z[...,k,...] = sum[..., i_l, ...] x[..., i_l,...] * conj(y[..., i_l - k,...])
167167
168-
This way, if x and y are 1-D arrays and ``z = correlate(x, y, 'full')`` then
168+
This way, if x and y are 1-D arrays and ``z = correlate(x, y, 'full')``
169+
then
169170
170171
.. math::
171172
@@ -227,46 +228,51 @@ def correlate(in1, in2, mode='full', method='auto'):
227228
if method in ('fft', 'auto'):
228229
return convolve(in1, _reverse_and_conj(in2), mode, method)
229230

230-
# fastpath to faster numpy.correlate for 1d inputs when possible
231-
if _np_conv_ok(in1, in2, mode):
232-
return np.correlate(in1, in2, mode)
231+
elif method == 'direct':
232+
# fastpath to faster numpy.correlate for 1d inputs when possible
233+
if _np_conv_ok(in1, in2, mode):
234+
return np.correlate(in1, in2, mode)
233235

234-
# _correlateND is far slower when in2.size > in1.size, so swap them
235-
# and then undo the effect afterward if mode == 'full'. Also, it fails
236-
# with 'valid' mode if in2 is larger than in1, so swap those, too.
237-
# Don't swap inputs for 'same' mode, since shape of in1 matters.
238-
swapped_inputs = ((mode == 'full') and (in2.size > in1.size) or
239-
_inputs_swap_needed(mode, in1.shape, in2.shape))
236+
# _correlateND is far slower when in2.size > in1.size, so swap them
237+
# and then undo the effect afterward if mode == 'full'. Also, it fails
238+
# with 'valid' mode if in2 is larger than in1, so swap those, too.
239+
# Don't swap inputs for 'same' mode, since shape of in1 matters.
240+
swapped_inputs = ((mode == 'full') and (in2.size > in1.size) or
241+
_inputs_swap_needed(mode, in1.shape, in2.shape))
240242

241-
if swapped_inputs:
242-
in1, in2 = in2, in1
243+
if swapped_inputs:
244+
in1, in2 = in2, in1
243245

244-
if mode == 'valid':
245-
ps = [i - j + 1 for i, j in zip(in1.shape, in2.shape)]
246-
out = np.empty(ps, in1.dtype)
246+
if mode == 'valid':
247+
ps = [i - j + 1 for i, j in zip(in1.shape, in2.shape)]
248+
out = np.empty(ps, in1.dtype)
247249

248-
z = sigtools._correlateND(in1, in2, out, val)
250+
z = sigtools._correlateND(in1, in2, out, val)
249251

250-
else:
251-
ps = [i + j - 1 for i, j in zip(in1.shape, in2.shape)]
252+
else:
253+
ps = [i + j - 1 for i, j in zip(in1.shape, in2.shape)]
252254

253-
# zero pad input
254-
in1zpadded = np.zeros(ps, in1.dtype)
255-
sc = [slice(0, i) for i in in1.shape]
256-
in1zpadded[sc] = in1.copy()
255+
# zero pad input
256+
in1zpadded = np.zeros(ps, in1.dtype)
257+
sc = [slice(0, i) for i in in1.shape]
258+
in1zpadded[sc] = in1.copy()
257259

258-
if mode == 'full':
259-
out = np.empty(ps, in1.dtype)
260-
elif mode == 'same':
261-
out = np.empty(in1.shape, in1.dtype)
260+
if mode == 'full':
261+
out = np.empty(ps, in1.dtype)
262+
elif mode == 'same':
263+
out = np.empty(in1.shape, in1.dtype)
262264

263-
z = sigtools._correlateND(in1zpadded, in2, out, val)
265+
z = sigtools._correlateND(in1zpadded, in2, out, val)
264266

265-
if swapped_inputs:
266-
# Reverse and conjugate to undo the effect of swapping inputs
267-
z = _reverse_and_conj(z)
267+
if swapped_inputs:
268+
# Reverse and conjugate to undo the effect of swapping inputs
269+
z = _reverse_and_conj(z)
270+
271+
return z
268272

269-
return z
273+
else:
274+
raise ValueError("Acceptable method flags are 'auto',"
275+
" 'direct', or 'fft'.")
270276

271277

272278
def _centered(arr, newshape):
@@ -298,8 +304,6 @@ def fftconvolve(in1, in2, mode="full"):
298304
First input.
299305
in2 : array_like
300306
Second input. Should have the same number of dimensions as `in1`.
301-
If operating in 'valid' mode, either `in1` or `in2` must be
302-
at least as large as the other in every dimension.
303307
mode : str {'full', 'valid', 'same'}, optional
304308
A string indicating the size of the output:
305309
@@ -308,7 +312,8 @@ def fftconvolve(in1, in2, mode="full"):
308312
of the inputs. (Default)
309313
``valid``
310314
The output consists only of those elements that do not
311-
rely on the zero-padding.
315+
rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
316+
must be at least as large as the other in every dimension.
312317
``same``
313318
The output is the same size as `in1`, centered
314319
with respect to the 'full' output.
@@ -474,7 +479,8 @@ def _fftconv_faster(x, h, mode):
474479
out_shape = [n - k + 1 for n, k in zip(x.shape, h.shape)]
475480
big_O_constant = 41954.28006344 if x.ndim == 1 else 66453.24316434
476481
else:
477-
raise ValueError('mode is invalid')
482+
raise ValueError("Acceptable mode flags are 'valid',"
483+
" 'same', or 'full'.")
478484

479485
# see whether the Fourier transform convolution method or the direct
480486
# convolution method is faster (discussed in scikit-image PR #1792)
@@ -497,9 +503,16 @@ def _np_conv_ok(volume, kernel, mode):
497503
See if numpy supports convolution of `volume` and `kernel` (i.e. both are
498504
1D ndarrays and of the appropriate shape). Numpy's 'same' mode uses the
499505
size of the larger input, while Scipy's uses the size of the first input.
506+
507+
Invalid mode strings will return False and be caught by the calling func.
500508
"""
501-
np_conv_ok = volume.ndim == kernel.ndim == 1
502-
return np_conv_ok and (volume.size >= kernel.size or mode != 'same')
509+
if volume.ndim == kernel.ndim == 1:
510+
if mode in ('full', 'valid'):
511+
return True
512+
elif mode == 'same':
513+
return volume.size >= kernel.size
514+
else:
515+
return False
503516

504517

505518
def _timeit_fast(stmt="pass", setup="pass", repeat=3):
@@ -755,6 +768,9 @@ def convolve(in1, in2, mode='full', method='auto'):
755768

756769
if volume.ndim == kernel.ndim == 0:
757770
return volume * kernel
771+
elif volume.ndim != kernel.ndim:
772+
raise ValueError("volume and kernel should have the same "
773+
"dimensionality")
758774

759775
if _inputs_swap_needed(mode, volume.shape, kernel.shape):
760776
# Convolution is commutative; order doesn't have any effect on output
@@ -769,12 +785,15 @@ def convolve(in1, in2, mode='full', method='auto'):
769785
if result_type.kind in {'u', 'i'}:
770786
out = np.around(out)
771787
return out.astype(result_type)
788+
elif method == 'direct':
789+
# fastpath to faster numpy.convolve for 1d inputs when possible
790+
if _np_conv_ok(volume, kernel, mode):
791+
return np.convolve(volume, kernel, mode)
772792

773-
# fastpath to faster numpy.convolve for 1d inputs when possible
774-
if _np_conv_ok(volume, kernel, mode):
775-
return np.convolve(volume, kernel, mode)
776-
777-
return correlate(volume, _reverse_and_conj(kernel), mode, 'direct')
793+
return correlate(volume, _reverse_and_conj(kernel), mode, 'direct')
794+
else:
795+
raise ValueError("Acceptable method flags are 'auto',"
796+
" 'direct', or 'fft'.")
778797

779798

780799
def order_filter(a, domain, rank):
@@ -945,8 +964,6 @@ def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
945964
First input.
946965
in2 : array_like
947966
Second input. Should have the same number of dimensions as `in1`.
948-
If operating in 'valid' mode, either `in1` or `in2` must be
949-
at least as large as the other in every dimension.
950967
mode : str {'full', 'valid', 'same'}, optional
951968
A string indicating the size of the output:
952969
@@ -955,11 +972,11 @@ def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
955972
of the inputs. (Default)
956973
``valid``
957974
The output consists only of those elements that do not
958-
rely on the zero-padding.
975+
rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
976+
must be at least as large as the other in every dimension.
959977
``same``
960978
The output is the same size as `in1`, centered
961979
with respect to the 'full' output.
962-
963980
boundary : str {'fill', 'wrap', 'symm'}, optional
964981
A flag indicating how to handle boundaries:
965982
@@ -1036,8 +1053,6 @@ def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
10361053
First input.
10371054
in2 : array_like
10381055
Second input. Should have the same number of dimensions as `in1`.
1039-
If operating in 'valid' mode, either `in1` or `in2` must be
1040-
at least as large as the other in every dimension.
10411056
mode : str {'full', 'valid', 'same'}, optional
10421057
A string indicating the size of the output:
10431058
@@ -1046,11 +1061,11 @@ def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
10461061
of the inputs. (Default)
10471062
``valid``
10481063
The output consists only of those elements that do not
1049-
rely on the zero-padding.
1064+
rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
1065+
must be at least as large as the other in every dimension.
10501066
``same``
10511067
The output is the same size as `in1`, centered
10521068
with respect to the 'full' output.
1053-
10541069
boundary : str {'fill', 'wrap', 'symm'}, optional
10551070
A flag indicating how to handle boundaries:
10561071

scipy/signal/tests/test_signaltools.py

Lines changed: 82 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,16 @@ def test_input_swapping(self):
119119
assert_array_equal(convolve(big, small, 'valid'),
120120
out_array[1:3, 1:3, 1:3])
121121

122+
def test_invalid_params(self):
123+
a = [3, 4, 5]
124+
b = [1, 2, 3]
125+
assert_raises(ValueError, convolve, a, b, mode='spam')
126+
assert_raises(ValueError, convolve, a, b, mode='eggs', method='fft')
127+
assert_raises(ValueError, convolve, a, b, mode='ham', method='direct')
128+
assert_raises(ValueError, convolve, a, b, mode='full', method='bacon')
129+
assert_raises(ValueError, convolve, a, b, mode='same', method='bacon')
130+
131+
122132
class TestConvolve(_TestConvolve):
123133

124134
def test_valid_mode2(self):
@@ -225,6 +235,15 @@ def test_convolve_method_large_input(self):
225235
assert_equal(fft, 2**(2*n))
226236
assert_equal(direct, 2**(2*n))
227237

238+
def test_mismatched_dims(self):
239+
# Input arrays should have the same number of dimensions
240+
assert_raises(ValueError, convolve, [1], 2, method='direct')
241+
assert_raises(ValueError, convolve, 1, [2], method='direct')
242+
assert_raises(ValueError, convolve, [1], 2, method='fft')
243+
assert_raises(ValueError, convolve, 1, [2], method='fft')
244+
assert_raises(ValueError, convolve, [1], [[2]])
245+
assert_raises(ValueError, convolve, [3], 2)
246+
228247

229248
class _TestConvolve2d(object):
230249

@@ -376,6 +395,11 @@ def test_consistency_convolve_funcs(self):
376395
signal.convolve2d([a], [b], mode=mode)),
377396
signal.convolve(a, b, mode=mode))
378397

398+
def test_invalid_dims(self):
399+
assert_raises(ValueError, convolve2d, 3, 4)
400+
assert_raises(ValueError, convolve2d, [3], [4])
401+
assert_raises(ValueError, convolve2d, [[[3]]], [[[4]]])
402+
379403

380404
class TestFFTConvolve(object):
381405

@@ -509,6 +533,15 @@ def test_invalid_shapes(self):
509533
assert_raises(ValueError, fftconvolve, *(a, b), **{'mode': 'valid'})
510534
assert_raises(ValueError, fftconvolve, *(b, a), **{'mode': 'valid'})
511535

536+
def test_mismatched_dims(self):
537+
assert_raises(ValueError, fftconvolve, [1], 2)
538+
assert_raises(ValueError, fftconvolve, 1, [2])
539+
assert_raises(ValueError, fftconvolve, [1], [[2]])
540+
assert_raises(ValueError, fftconvolve, [3], 2)
541+
542+
def test_invalid_flags(self):
543+
assert_raises(ValueError, fftconvolve, [1], [2], mode='chips')
544+
512545

513546
class TestMedFilt(object):
514547

@@ -1204,9 +1237,10 @@ def test_lfilter_notimplemented_input():
12041237
assert_raises(NotImplementedError, lfilter, [2,3], [4,5], [1,2,3,4,5])
12051238

12061239

1207-
@pytest.mark.parametrize('dt', [np.ubyte, np.byte, np.ushort, np.short, np.uint, int,
1208-
np.ulonglong, np.ulonglong, np.float32, np.float64,
1209-
np.longdouble, Decimal])
1240+
@pytest.mark.parametrize('dt', [np.ubyte, np.byte, np.ushort, np.short,
1241+
np.uint, int, np.ulonglong, np.ulonglong,
1242+
np.float32, np.float64, np.longdouble,
1243+
Decimal])
12101244
class TestCorrelateReal(object):
12111245
def _setup_rank1(self, dt):
12121246
a = np.linspace(0, 3, 4).astype(dt)
@@ -1316,7 +1350,11 @@ def test_rank3_all(self, dt):
13161350
assert_array_almost_equal(y, y_r)
13171351
assert_equal(y.dtype, dt)
13181352

1319-
def test_invalid_shapes(self, dt):
1353+
1354+
class TestCorrelate(object):
1355+
# Tests that don't depend on dtype
1356+
1357+
def test_invalid_shapes(self):
13201358
# By "invalid," we mean that no one
13211359
# array has dimensions that are all at
13221360
# least as large as the corresponding
@@ -1328,6 +1366,35 @@ def test_invalid_shapes(self, dt):
13281366
assert_raises(ValueError, correlate, *(a, b), **{'mode': 'valid'})
13291367
assert_raises(ValueError, correlate, *(b, a), **{'mode': 'valid'})
13301368

1369+
def test_invalid_params(self):
1370+
a = [3, 4, 5]
1371+
b = [1, 2, 3]
1372+
assert_raises(ValueError, correlate, a, b, mode='spam')
1373+
assert_raises(ValueError, correlate, a, b, mode='eggs', method='fft')
1374+
assert_raises(ValueError, correlate, a, b, mode='ham', method='direct')
1375+
assert_raises(ValueError, correlate, a, b, mode='full', method='bacon')
1376+
assert_raises(ValueError, correlate, a, b, mode='same', method='bacon')
1377+
1378+
def test_mismatched_dims(self):
1379+
# Input arrays should have the same number of dimensions
1380+
assert_raises(ValueError, correlate, [1], 2, method='direct')
1381+
assert_raises(ValueError, correlate, 1, [2], method='direct')
1382+
assert_raises(ValueError, correlate, [1], 2, method='fft')
1383+
assert_raises(ValueError, correlate, 1, [2], method='fft')
1384+
assert_raises(ValueError, correlate, [1], [[2]])
1385+
assert_raises(ValueError, correlate, [3], 2)
1386+
1387+
def test_numpy_fastpath(self):
1388+
a = [1, 2, 3]
1389+
b = [4, 5]
1390+
assert_allclose(correlate(a, b, mode='same'), [5, 14, 23])
1391+
1392+
a = [1, 2, 3]
1393+
b = [4, 5, 6]
1394+
assert_allclose(correlate(a, b, mode='same'), [17, 32, 23])
1395+
assert_allclose(correlate(a, b, mode='full'), [6, 17, 32, 23, 12])
1396+
assert_allclose(correlate(a, b, mode='valid'), [32])
1397+
13311398

13321399
@pytest.mark.parametrize('dt', [np.csingle, np.cdouble, np.clongdouble])
13331400
class TestCorrelateComplex(object):
@@ -2304,3 +2371,14 @@ def test_sosfilt_zi(self):
23042371
# Expected steady state value of the step response of this filter:
23052372
ss = np.prod(sos[:, :3].sum(axis=-1) / sos[:, 3:].sum(axis=-1))
23062373
assert_allclose(y, ss, rtol=1e-13)
2374+
2375+
2376+
class TestDeconvolve(object):
2377+
2378+
def test_basic(self):
2379+
# From docstring example
2380+
original = [0, 1, 0, 0, 1, 1, 0, 0]
2381+
impulse_response = [2, 1]
2382+
recorded = [0, 2, 1, 0, 2, 3, 1, 0, 0]
2383+
recovered, remainder = signal.deconvolve(recorded, impulse_response)
2384+
assert_allclose(recovered, original)

0 commit comments

Comments
 (0)