37
37
38
38
"""
39
39
40
+ import math
41
+
40
42
import dpctl .tensor as dpt
41
43
import dpctl .utils as dpu
42
44
import numpy
55
57
from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
56
58
from .dpnp_utils .dpnp_utils_statistics import dpnp_cov , dpnp_median
57
59
60
+ min_ = min # pylint: disable=used-before-assignment
61
+
58
62
__all__ = [
59
63
"amax" ,
60
64
"amin" ,
@@ -451,16 +455,55 @@ def _get_padding(a_size, v_size, mode):
451
455
return l_pad , r_pad
452
456
453
457
454
- def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad ):
458
+ def _choose_conv_method (a , v , rdtype ):
459
+ assert a .size >= v .size
460
+ if rdtype == dpnp .bool :
461
+ return "direct"
462
+
463
+ if v .size < 10 ** 4 or a .size < 10 ** 4 :
464
+ return "direct"
465
+
466
+ if dpnp .issubdtype (rdtype , dpnp .integer ):
467
+ max_a = int (dpnp .max (dpnp .abs (a )))
468
+ sum_v = int (dpnp .sum (dpnp .abs (v )))
469
+ max_value = int (max_a * sum_v )
470
+
471
+ default_float = dpnp .default_float_type (a .sycl_device )
472
+ if max_value > 2 ** numpy .finfo (default_float ).nmant - 1 :
473
+ return "direct"
474
+
475
+ if dpnp .issubdtype (rdtype , dpnp .number ):
476
+ return "fft"
477
+
478
+ raise ValueError (f"Unsupported dtype: { rdtype } " )
479
+
480
+
481
+ def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype ):
455
482
queue = a .sycl_queue
483
+ device = a .sycl_device
456
484
457
- usm_type = dpu .get_coerced_usm_type ([a .usm_type , v .usm_type ])
458
- out_size = l_pad + r_pad + a .size - v .size + 1
485
+ supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
486
+ supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
487
+
488
+ if supported_dtype is None :
489
+ raise ValueError (
490
+ f"function does not support input types "
491
+ f"({ a .dtype .name } , { v .dtype .name } ), "
492
+ "and the inputs could not be coerced to any "
493
+ f"supported types. List of supported types: "
494
+ f"{ [st .name for st in supported_types ]} "
495
+ )
496
+
497
+ a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
498
+ v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
499
+
500
+ usm_type = dpu .get_coerced_usm_type ([a_casted .usm_type , v_casted .usm_type ])
501
+ out_size = l_pad + r_pad + a_casted .size - v_casted .size + 1
459
502
# out type is the same as input type
460
- out = dpnp .empty_like (a , shape = out_size , usm_type = usm_type )
503
+ out = dpnp .empty_like (a_casted , shape = out_size , usm_type = usm_type )
461
504
462
- a_usm = dpnp .get_usm_ndarray (a )
463
- v_usm = dpnp .get_usm_ndarray (v )
505
+ a_usm = dpnp .get_usm_ndarray (a_casted )
506
+ v_usm = dpnp .get_usm_ndarray (v_casted )
464
507
out_usm = dpnp .get_usm_ndarray (out )
465
508
466
509
_manager = dpu .SequentialOrderManager [queue ]
@@ -478,7 +521,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
478
521
return out
479
522
480
523
481
- def correlate (a , v , mode = "valid" ):
524
+ def _convolve_fft (a , v , l_pad , r_pad , rtype ):
525
+ assert a .size >= v .size
526
+ assert l_pad < v .size
527
+
528
+ # +1 is needed to avoid circular convolution
529
+ padded_size = a .size + r_pad + 1
530
+ fft_size = 2 ** int (math .ceil (math .log2 (padded_size )))
531
+
532
+ af = dpnp .fft .fft (a , fft_size ) # pylint: disable=no-member
533
+ vf = dpnp .fft .fft (v , fft_size ) # pylint: disable=no-member
534
+
535
+ r = dpnp .fft .ifft (af * vf ) # pylint: disable=no-member
536
+ if dpnp .issubdtype (rtype , dpnp .floating ):
537
+ r = r .real
538
+ elif dpnp .issubdtype (rtype , dpnp .integer ) or rtype == dpnp .bool :
539
+ r = r .real .round ()
540
+
541
+ start = v .size - 1 - l_pad
542
+ end = padded_size - 1
543
+
544
+ return r [start :end ]
545
+
546
+
547
+ def correlate (a , v , mode = "valid" , method = "auto" ):
482
548
r"""
483
549
Cross-correlation of two 1-dimensional sequences.
484
550
@@ -503,10 +569,24 @@ def correlate(a, v, mode="valid"):
503
569
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
504
570
505
571
Default: ``"valid"``.
572
+ method : {"auto", "direct", "fft"}, optional
573
+ `"direct"`: The correlation is determined directly from sums.
574
+
575
+ `"fft"`: The Fourier Transform is used to perform the calculations.
576
+ This method is faster for long sequences but can have accuracy issues.
577
+
578
+ `"auto"`: Automatically chooses direct or Fourier method based on
579
+ an estimate of which is faster.
580
+
581
+ Note: Use of the FFT convolution on input containing NAN or INF
582
+ will lead to the entire output being NAN or INF.
583
+ Use method='direct' when your input contains NAN or INF values.
584
+
585
+ Default: ``"auto"``.
506
586
507
587
Returns
508
588
-------
509
- out : dpnp.ndarray
589
+ out : { dpnp.ndarray}
510
590
Discrete cross-correlation of `a` and `v`.
511
591
512
592
Notes
@@ -570,20 +650,14 @@ def correlate(a, v, mode="valid"):
570
650
f"Received shapes: a.shape={ a .shape } , v.shape={ v .shape } "
571
651
)
572
652
573
- supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
653
+ supported_methods = ["auto" , "direct" , "fft" ]
654
+ if method not in supported_methods :
655
+ raise ValueError (
656
+ f"Unknown method: { method } . Supported methods: { supported_methods } "
657
+ )
574
658
575
659
device = a .sycl_device
576
660
rdtype = result_type_for_device ([a .dtype , v .dtype ], device )
577
- supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
578
-
579
- if supported_dtype is None :
580
- raise ValueError (
581
- f"function does not support input types "
582
- f"({ a .dtype .name } , { v .dtype .name } ), "
583
- "and the inputs could not be coerced to any "
584
- f"supported types. List of supported types: "
585
- f"{ [st .name for st in supported_types ]} "
586
- )
587
661
588
662
if dpnp .issubdtype (v .dtype , dpnp .complexfloating ):
589
663
v = dpnp .conj (v )
@@ -595,10 +669,15 @@ def correlate(a, v, mode="valid"):
595
669
596
670
l_pad , r_pad = _get_padding (a .size , v .size , mode )
597
671
598
- a_casted = dpnp . asarray ( a , dtype = supported_dtype , order = "C" )
599
- v_casted = dpnp . asarray ( v , dtype = supported_dtype , order = "C" )
672
+ if method == "auto" :
673
+ method = _choose_conv_method ( a , v , rdtype )
600
674
601
- r = _run_native_sliding_dot_product1d (a_casted , v_casted , l_pad , r_pad )
675
+ if method == "direct" :
676
+ r = _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype )
677
+ elif method == "fft" :
678
+ r = _convolve_fft (a , v [::- 1 ], l_pad , r_pad , rdtype )
679
+ else :
680
+ raise ValueError (f"Unknown method: { method } " )
602
681
603
682
if revert :
604
683
r = r [::- 1 ]
0 commit comments