Skip to content

Commit fafbb6c

Browse files
Merge 4a715c7 into 44bb068
2 parents 44bb068 + 4a715c7 commit fafbb6c

File tree

2 files changed

+194
-30
lines changed

2 files changed

+194
-30
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 101 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import math
41+
4042
import dpctl.tensor as dpt
4143
import dpctl.utils as dpu
4244
import numpy
@@ -55,6 +57,8 @@
5557
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
5658
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov, dpnp_median
5759

60+
min_ = min # pylint: disable=used-before-assignment
61+
5862
__all__ = [
5963
"amax",
6064
"amin",
@@ -451,16 +455,55 @@ def _get_padding(a_size, v_size, mode):
451455
return l_pad, r_pad
452456

453457

454-
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
458+
def _choose_conv_method(a, v, rdtype):
459+
assert a.size >= v.size
460+
if rdtype == dpnp.bool:
461+
return "direct"
462+
463+
if v.size < 10**4 or a.size < 10**4:
464+
return "direct"
465+
466+
if dpnp.issubdtype(rdtype, dpnp.integer):
467+
max_a = int(dpnp.max(dpnp.abs(a)))
468+
sum_v = int(dpnp.sum(dpnp.abs(v)))
469+
max_value = int(max_a * sum_v)
470+
471+
default_float = dpnp.default_float_type(a.sycl_device)
472+
if max_value > 2 ** numpy.finfo(default_float).nmant - 1:
473+
return "direct"
474+
475+
if dpnp.issubdtype(rdtype, dpnp.number):
476+
return "fft"
477+
478+
raise ValueError(f"Unsupported dtype: {rdtype}")
479+
480+
481+
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
455482
queue = a.sycl_queue
483+
device = a.sycl_device
456484

457-
usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
458-
out_size = l_pad + r_pad + a.size - v.size + 1
485+
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
486+
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
487+
488+
if supported_dtype is None:
489+
raise ValueError(
490+
f"function does not support input types "
491+
f"({a.dtype.name}, {v.dtype.name}), "
492+
"and the inputs could not be coerced to any "
493+
f"supported types. List of supported types: "
494+
f"{[st.name for st in supported_types]}"
495+
)
496+
497+
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
498+
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
499+
500+
usm_type = dpu.get_coerced_usm_type([a_casted.usm_type, v_casted.usm_type])
501+
out_size = l_pad + r_pad + a_casted.size - v_casted.size + 1
459502
# out type is the same as input type
460-
out = dpnp.empty_like(a, shape=out_size, usm_type=usm_type)
503+
out = dpnp.empty_like(a_casted, shape=out_size, usm_type=usm_type)
461504

462-
a_usm = dpnp.get_usm_ndarray(a)
463-
v_usm = dpnp.get_usm_ndarray(v)
505+
a_usm = dpnp.get_usm_ndarray(a_casted)
506+
v_usm = dpnp.get_usm_ndarray(v_casted)
464507
out_usm = dpnp.get_usm_ndarray(out)
465508

466509
_manager = dpu.SequentialOrderManager[queue]
@@ -478,7 +521,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
478521
return out
479522

480523

481-
def correlate(a, v, mode="valid"):
524+
def _convolve_fft(a, v, l_pad, r_pad, rtype):
525+
assert a.size >= v.size
526+
assert l_pad < v.size
527+
528+
# +1 is needed to avoid circular convolution
529+
padded_size = a.size + r_pad + 1
530+
fft_size = 2 ** int(math.ceil(math.log2(padded_size)))
531+
532+
af = dpnp.fft.fft(a, fft_size) # pylint: disable=no-member
533+
vf = dpnp.fft.fft(v, fft_size) # pylint: disable=no-member
534+
535+
r = dpnp.fft.ifft(af * vf) # pylint: disable=no-member
536+
if dpnp.issubdtype(rtype, dpnp.floating):
537+
r = r.real
538+
elif dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
539+
r = r.real.round()
540+
541+
start = v.size - 1 - l_pad
542+
end = padded_size - 1
543+
544+
return r[start:end]
545+
546+
547+
def correlate(a, v, mode="valid", method="auto"):
482548
r"""
483549
Cross-correlation of two 1-dimensional sequences.
484550
@@ -503,10 +569,24 @@ def correlate(a, v, mode="valid"):
503569
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
504570
505571
Default: ``"valid"``.
572+
method : {"auto", "direct", "fft"}, optional
573+
`"direct"`: The correlation is determined directly from sums.
574+
575+
`"fft"`: The Fourier Transform is used to perform the calculations.
576+
This method is faster for long sequences but can have accuracy issues.
577+
578+
`"auto"`: Automatically chooses direct or Fourier method based on
579+
an estimate of which is faster.
580+
581+
Note: Use of the FFT convolution on input containing NAN or INF
582+
will lead to the entire output being NAN or INF.
583+
Use method='direct' when your input contains NAN or INF values.
584+
585+
Default: ``"auto"``.
506586
507587
Returns
508588
-------
509-
out : dpnp.ndarray
589+
out : {dpnp.ndarray}
510590
Discrete cross-correlation of `a` and `v`.
511591
512592
Notes
@@ -570,20 +650,14 @@ def correlate(a, v, mode="valid"):
570650
f"Received shapes: a.shape={a.shape}, v.shape={v.shape}"
571651
)
572652

573-
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
653+
supported_methods = ["auto", "direct", "fft"]
654+
if method not in supported_methods:
655+
raise ValueError(
656+
f"Unknown method: {method}. Supported methods: {supported_methods}"
657+
)
574658

575659
device = a.sycl_device
576660
rdtype = result_type_for_device([a.dtype, v.dtype], device)
577-
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
578-
579-
if supported_dtype is None:
580-
raise ValueError(
581-
f"function does not support input types "
582-
f"({a.dtype.name}, {v.dtype.name}), "
583-
"and the inputs could not be coerced to any "
584-
f"supported types. List of supported types: "
585-
f"{[st.name for st in supported_types]}"
586-
)
587661

588662
if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
589663
v = dpnp.conj(v)
@@ -595,10 +669,15 @@ def correlate(a, v, mode="valid"):
595669

596670
l_pad, r_pad = _get_padding(a.size, v.size, mode)
597671

598-
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
599-
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
672+
if method == "auto":
673+
method = _choose_conv_method(a, v, rdtype)
600674

601-
r = _run_native_sliding_dot_product1d(a_casted, v_casted, l_pad, r_pad)
675+
if method == "direct":
676+
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
677+
elif method == "fft":
678+
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
679+
else:
680+
raise ValueError(f"Unknown method: {method}")
602681

603682
if revert:
604683
r = r[::-1]

dpnp/tests/test_statistics.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -583,26 +583,104 @@ def test_corrcoef_scalar(self):
583583

584584

585585
class TestCorrelate:
586+
def setup_method(self):
587+
numpy.random.seed(0)
588+
586589
@pytest.mark.parametrize(
587590
"a, v", [([1], [1, 2, 3]), ([1, 2, 3], [1]), ([1, 2, 3], [1, 2])]
588591
)
589592
@pytest.mark.parametrize("mode", [None, "full", "valid", "same"])
590593
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
591-
def test_correlate(self, a, v, mode, dtype):
594+
@pytest.mark.parametrize("method", [None, "auto", "direct", "fft"])
595+
def test_correlate(self, a, v, mode, dtype, method):
592596
an = numpy.array(a, dtype=dtype)
593597
vn = numpy.array(v, dtype=dtype)
594598
ad = dpnp.array(an)
595599
vd = dpnp.array(vn)
596600

597-
if mode is None:
598-
expected = numpy.correlate(an, vn)
599-
result = dpnp.correlate(ad, vd)
600-
else:
601-
expected = numpy.correlate(an, vn, mode=mode)
602-
result = dpnp.correlate(ad, vd, mode=mode)
601+
dpnp_kwargs = {}
602+
numpy_kwargs = {}
603+
if mode is not None:
604+
dpnp_kwargs["mode"] = mode
605+
numpy_kwargs["mode"] = mode
606+
if method is not None:
607+
dpnp_kwargs["method"] = method
608+
609+
expected = numpy.correlate(an, vn, **numpy_kwargs)
610+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
603611

604612
assert_dtype_allclose(result, expected)
605613

614+
@pytest.mark.parametrize("a_size", [1, 100, 10000])
615+
@pytest.mark.parametrize("v_size", [1, 100, 10000])
616+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
617+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
618+
@pytest.mark.parametrize("method", ["auto", "direct", "fft"])
619+
def test_correlate_random(self, a_size, v_size, mode, dtype, method):
620+
if dtype == dpnp.bool:
621+
an = numpy.random.rand(a_size) > 0.9
622+
vn = numpy.random.rand(v_size) > 0.9
623+
else:
624+
an = (100 * numpy.random.rand(a_size)).astype(dtype)
625+
vn = (100 * numpy.random.rand(v_size)).astype(dtype)
626+
627+
if dpnp.issubdtype(dtype, dpnp.complexfloating):
628+
an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
629+
vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)
630+
631+
ad = dpnp.array(an)
632+
vd = dpnp.array(vn)
633+
634+
dpnp_kwargs = {}
635+
numpy_kwargs = {}
636+
if mode is not None:
637+
dpnp_kwargs["mode"] = mode
638+
numpy_kwargs["mode"] = mode
639+
if method is not None:
640+
dpnp_kwargs["method"] = method
641+
642+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
643+
expected = numpy.correlate(an, vn, **numpy_kwargs)
644+
645+
rdtype = result.dtype
646+
if dpnp.issubdtype(rdtype, dpnp.integer):
647+
rdtype = dpnp.default_float_type(ad.device)
648+
649+
if method != "fft" and (
650+
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
651+
):
652+
# For 'direct' and 'auto' methods, we expect exact results for integer types
653+
assert_array_equal(result, expected)
654+
else:
655+
result = result.astype(rdtype)
656+
if method == "direct":
657+
expected = numpy.correlate(an, vn, **numpy_kwargs)
658+
# For 'direct' method we can use standard validation
659+
assert_dtype_allclose(result, expected, factor=30)
660+
else:
661+
rtol = 1e-3
662+
atol = 1e-10
663+
664+
if rdtype == dpnp.float64 or rdtype == dpnp.complex128:
665+
rtol = 1e-6
666+
atol = 1e-12
667+
elif rdtype == dpnp.bool:
668+
result = result.astype(dpnp.int32)
669+
rdtype = result.dtype
670+
671+
expected = expected.astype(rdtype)
672+
673+
diff = numpy.abs(result.asnumpy() - expected)
674+
invalid = diff > atol + rtol * numpy.abs(expected)
675+
676+
# When using the 'fft' method, we might encounter outliers.
677+
# This usually happens when the resulting array contains values close to zero.
678+
# For these outliers, the relative error can be significant.
679+
# We can tolerate a few such outliers.
680+
max_outliers = 8 if expected.size > 1 else 0
681+
if invalid.sum() > max_outliers:
682+
assert_dtype_allclose(result, expected, factor=1000)
683+
606684
def test_correlate_mode_error(self):
607685
a = dpnp.arange(5)
608686
v = dpnp.arange(3)
@@ -643,7 +721,7 @@ def test_correlate_different_sizes(self, size):
643721
vd = dpnp.array(v)
644722

645723
expected = numpy.correlate(a, v)
646-
result = dpnp.correlate(ad, vd)
724+
result = dpnp.correlate(ad, vd, method="direct")
647725

648726
assert_dtype_allclose(result, expected, factor=20)
649727

@@ -654,6 +732,13 @@ def test_correlate_another_sycl_queue(self):
654732
with pytest.raises(ValueError):
655733
dpnp.correlate(a, v)
656734

735+
def test_correlate_unkown_method(self):
736+
a = dpnp.arange(5)
737+
v = dpnp.arange(3)
738+
739+
with pytest.raises(ValueError):
740+
dpnp.correlate(a, v, method="unknown")
741+
657742

658743
class TestCov:
659744
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)