Skip to content

Commit 1188a6d

Browse files
Apply review comments
1 parent 9529628 commit 1188a6d

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def _get_padding(a_size, v_size, mode):
447447
r_pad = v_size - l_pad - 1
448448
elif mode == "full":
449449
l_pad, r_pad = v_size - 1, v_size - 1
450-
else:
450+
else: # pragma: no cover
451451
raise ValueError(
452452
f"Unknown mode: {mode}. Only 'valid', 'same', 'full' are supported."
453453
)
@@ -458,9 +458,11 @@ def _get_padding(a_size, v_size, mode):
458458
def _choose_conv_method(a, v, rdtype):
459459
assert a.size >= v.size
460460
if rdtype == dpnp.bool:
461+
# to avoid accuracy issues
461462
return "direct"
462463

463464
if v.size < 10**4 or a.size < 10**4:
465+
# direct method is faster for small arrays
464466
return "direct"
465467

466468
if dpnp.issubdtype(rdtype, dpnp.integer):
@@ -470,12 +472,13 @@ def _choose_conv_method(a, v, rdtype):
470472

471473
default_float = dpnp.default_float_type(a.sycl_device)
472474
if max_value > 2 ** numpy.finfo(default_float).nmant - 1:
475+
# can't represent the result in the default float type
473476
return "direct"
474477

475478
if dpnp.issubdtype(rdtype, dpnp.number):
476479
return "fft"
477480

478-
raise ValueError(f"Unsupported dtype: {rdtype}")
481+
raise ValueError(f"Unsupported dtype: {rdtype}") # pragma: no cover
479482

480483

481484
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
@@ -485,7 +488,7 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
485488
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
486489
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
487490

488-
if supported_dtype is None:
491+
if supported_dtype is None: # pragma: no cover
489492
raise ValueError(
490493
f"function does not support input types "
491494
f"({a.dtype.name}, {v.dtype.name}), "
@@ -676,7 +679,7 @@ def correlate(a, v, mode="valid", method="auto"):
676679
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
677680
elif method == "fft":
678681
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
679-
else:
682+
else: # pragma: no cover
680683
raise ValueError(f"Unknown method: {method}")
681684

682685
if revert:

dpnp/tests/test_statistics.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,17 @@ def test_corrcoef_scalar(self):
583583

584584

585585
class TestCorrelate:
586+
@staticmethod
587+
def _get_kwargs(mode=None, method=None):
588+
dpnp_kwargs = {}
589+
numpy_kwargs = {}
590+
if mode is not None:
591+
dpnp_kwargs["mode"] = mode
592+
numpy_kwargs["mode"] = mode
593+
if method is not None:
594+
dpnp_kwargs["method"] = method
595+
return dpnp_kwargs, numpy_kwargs
596+
586597
def setup_method(self):
587598
numpy.random.seed(0)
588599

@@ -598,13 +609,7 @@ def test_correlate(self, a, v, mode, dtype, method):
598609
ad = dpnp.array(an)
599610
vd = dpnp.array(vn)
600611

601-
dpnp_kwargs = {}
602-
numpy_kwargs = {}
603-
if mode is not None:
604-
dpnp_kwargs["mode"] = mode
605-
numpy_kwargs["mode"] = mode
606-
if method is not None:
607-
dpnp_kwargs["method"] = method
612+
dpnp_kwargs, numpy_kwargs = self._get_kwargs(mode, method)
608613

609614
expected = numpy.correlate(an, vn, **numpy_kwargs)
610615
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
@@ -621,23 +626,17 @@ def test_correlate_random(self, a_size, v_size, mode, dtype, method):
621626
an = numpy.random.rand(a_size) > 0.9
622627
vn = numpy.random.rand(v_size) > 0.9
623628
else:
624-
an = (100 * numpy.random.rand(a_size)).astype(dtype)
625-
vn = (100 * numpy.random.rand(v_size)).astype(dtype)
629+
an = 100 * numpy.random.rand(a_size).astype(dtype)
630+
vn = 100 * numpy.random.rand(v_size).astype(dtype)
626631

627632
if dpnp.issubdtype(dtype, dpnp.complexfloating):
628-
an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
629-
vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)
633+
an = an + 100j * numpy.random.rand(a_size).astype(dtype)
634+
vn = vn + 100j * numpy.random.rand(v_size).astype(dtype)
630635

631636
ad = dpnp.array(an)
632637
vd = dpnp.array(vn)
633638

634-
dpnp_kwargs = {}
635-
numpy_kwargs = {}
636-
if mode is not None:
637-
dpnp_kwargs["mode"] = mode
638-
numpy_kwargs["mode"] = mode
639-
if method is not None:
640-
dpnp_kwargs["method"] = method
639+
dpnp_kwargs, numpy_kwargs = self._get_kwargs(mode, method)
641640

642641
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
643642
expected = numpy.correlate(an, vn, **numpy_kwargs)

0 commit comments

Comments
 (0)