37
37
# pylint: disable=c-extension-no-member
38
38
# pylint: disable=no-name-in-module
39
39
40
+ from collections .abc import Sequence
41
+
40
42
import dpctl
41
43
import dpctl .tensor ._tensor_impl as ti
42
44
import dpctl .utils as dpu
43
45
import numpy
44
- from dpctl .tensor ._numpy_helper import normalize_axis_index
46
+ from dpctl .tensor ._numpy_helper import (
47
+ normalize_axis_index ,
48
+ normalize_axis_tuple ,
49
+ )
45
50
from dpctl .utils import ExecutionPlacementError
46
51
47
52
import dpnp
55
60
56
61
__all__ = [
57
62
"dpnp_fft" ,
63
+ "dpnp_fftn" ,
58
64
]
59
65
60
66
@@ -66,6 +72,37 @@ def _check_norm(norm):
66
72
)
67
73
68
74
75
+ # TODO: the `c2r` keyword is a placeholder for irfftn
76
+ def _cook_nd_args (a , s = None , axes = None , c2r = False ):
77
+ if s is None :
78
+ shapeless = True
79
+ if axes is None :
80
+ s = list (a .shape )
81
+ else :
82
+ s = numpy .take (a .shape , axes )
83
+ else :
84
+ shapeless = False
85
+
86
+ for s_i in s :
87
+ if s_i is not None and s_i < 1 and s_i != - 1 :
88
+ raise ValueError (
89
+ f"Invalid number of FFT data points ({ s_i } ) specified."
90
+ )
91
+
92
+ if axes is None :
93
+ axes = list (range (- len (s ), 0 ))
94
+
95
+ if len (s ) != len (axes ):
96
+ raise ValueError ("Shape and axes have different lengths." )
97
+
98
+ s = list (s )
99
+ if c2r and shapeless :
100
+ s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
101
+ # use the whole input array along axis `i` if `s[i] == -1`
102
+ s = [a .shape [_a ] if _s == - 1 else _s for _s , _a in zip (s , axes )]
103
+ return s , axes
104
+
105
+
69
106
def _commit_descriptor (a , in_place , c2c , a_strides , index , axes ):
70
107
"""Commit the FFT descriptor for the input array."""
71
108
@@ -205,6 +242,63 @@ def _copy_array(x, complex_input):
205
242
return x , copy_flag
206
243
207
244
245
+ def _extract_axes_chunk (a , chunk_size = 3 ):
246
+ """
247
+ Classify input into a list of list with each list containing
248
+ only unique values and its length is at most `chunk_size`.
249
+
250
+ Parameters
251
+ ----------
252
+ a : list, tuple
253
+ Input.
254
+ chunk_size : int
255
+ Maximum number of elements in each chunk.
256
+
257
+ Return
258
+ ------
259
+ out : list of lists
260
+ List of lists with each list containing only unique values
261
+ and its length is at most `chunk_size`.
262
+ The final list is returned in reverse order.
263
+
264
+ Examples
265
+ --------
266
+ >>> axes = (0, 1, 2, 3, 4)
267
+ >>> _extract_axes_chunk(axes, chunk_size=3)
268
+ [[2, 3, 4], [0, 1]]
269
+
270
+ >>> axes = (0, 1, 2, 3, 4, 4)
271
+ >>> _extract_axes_chunk(axes, chunk_size=3)
272
+ [[4], [2, 3, 4], [0, 1]]
273
+
274
+ """
275
+
276
+ chunks = []
277
+ current_chunk = []
278
+ seen_elements = set ()
279
+
280
+ for elem in a :
281
+ if elem in seen_elements :
282
+ # If element is already seen, start a new chunk
283
+ chunks .append (current_chunk )
284
+ current_chunk = [elem ]
285
+ seen_elements = {elem }
286
+ else :
287
+ current_chunk .append (elem )
288
+ seen_elements .add (elem )
289
+
290
+ if len (current_chunk ) == chunk_size :
291
+ chunks .append (current_chunk )
292
+ current_chunk = []
293
+ seen_elements = set ()
294
+
295
+ # Add the last chunk if it's not empty
296
+ if current_chunk :
297
+ chunks .append (current_chunk )
298
+
299
+ return chunks [::- 1 ]
300
+
301
+
208
302
def _fft (a , norm , out , forward , in_place , c2c , axes = None ):
209
303
"""Calculates FFT of the input array along the specified axes."""
210
304
@@ -239,7 +333,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):
239
333
240
334
def _scale_result (res , a_shape , norm , forward , index ):
241
335
"""Scale the result of the FFT according to `norm`."""
242
- scale = numpy .prod (a_shape [index :], dtype = res .real .dtype )
336
+ if res .dtype in [dpnp .float32 , dpnp .complex64 ]:
337
+ dtype = dpnp .float32
338
+ else :
339
+ dtype = dpnp .float64
340
+ scale = numpy .prod (a_shape [index :], dtype = dtype )
243
341
norm_factor = 1
244
342
if norm == "ortho" :
245
343
norm_factor = numpy .sqrt (scale )
@@ -329,9 +427,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
329
427
raise TypeError ("output array should have complex data type." )
330
428
331
429
430
+ def _validate_s_axes (a , s , axes ):
431
+ if axes is not None :
432
+ # validate axes is a sequence and
433
+ # each axis is an integer within the range
434
+ normalize_axis_tuple (list (set (axes )), a .ndim , "axes" )
435
+
436
+ if s is not None :
437
+ raise_error = False
438
+ if isinstance (s , Sequence ):
439
+ if any (not isinstance (s_i , int ) for s_i in s ):
440
+ raise_error = True
441
+ else :
442
+ raise_error = True
443
+
444
+ if raise_error :
445
+ raise TypeError ("`s` must be `None` or a sequence of integers." )
446
+
447
+ if axes is None :
448
+ raise ValueError (
449
+ "`axes` should not be `None` if `s` is not `None`."
450
+ )
451
+
452
+
332
453
def dpnp_fft (a , forward , real , n = None , axis = - 1 , norm = None , out = None ):
333
454
"""Calculates 1-D FFT of the input array along axis"""
334
455
456
+ _check_norm (norm )
335
457
a_ndim = a .ndim
336
458
if a_ndim == 0 :
337
459
raise ValueError ("Input array must be at least 1D" )
@@ -378,3 +500,67 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
378
500
c2c = c2c ,
379
501
axes = axis ,
380
502
)
503
+
504
+
505
def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
    """
    Calculates N-D FFT of the input array along axes.

    Parameters
    ----------
    a : array
        Input array. It is returned unchanged when `axes` is an empty list
        or tuple, or when it is 0-dimensional and `axes` is ``None``.
    forward : bool
        Forwarded as-is to ``_fft`` (presumably selects the forward or the
        inverse transform -- confirm against ``_fft``).
    s : sequence of ints, optional
        Number of points along each transformed axis; validated by
        ``_validate_s_axes`` and resolved by ``_cook_nd_args``.
    axes : sequence of ints, optional
        Axes over which to compute the FFT. Repeated axes are allowed.
    norm : optional
        Normalization mode; checked by ``_check_norm`` and forwarded to
        ``_fft``.
    out : array, optional
        Output array; checked by ``_validate_out_keyword`` and forwarded to
        ``_fft``.

    Returns
    -------
    out : array
        Result of the transform.

    Raises
    ------
    IndexError
        If `a` is 0-dimensional while `axes` is not ``None``.

    Exceptions raised by the validation helpers propagate unchanged.

    """

    _check_norm(norm)
    if isinstance(axes, (list, tuple)) and len(axes) == 0:
        # no axis to transform
        return a

    if a.ndim == 0:
        if axes is not None:
            raise IndexError(
                "Input array is 0-dimensional while axis is not `None`."
            )

        return a

    _validate_s_axes(a, s, axes)
    s, axes = _cook_nd_args(a, s, axes)
    a = _truncate_or_pad(a, s, axes)
    # TODO: None, False, False are place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    _validate_out_keyword(a, out, None, False, False)
    # TODO: True is place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    a, in_place = _copy_array(a, True)

    if a.size == 0:
        return dpnp.get_result_array(a, out=out, casting="same_kind")

    # Map every axis to its non-negative form so that two spellings of the
    # same axis (e.g. `0` and `-a.ndim`) are recognized as a repeated axis
    # by the check below; the axes were already range-checked above
    axes = [normalize_axis_index(axis, a.ndim) for axis in axes]

    len_axes = len(axes)
    # OneMKL supports up to 3-dimensional FFT on GPU
    # repeated axis in OneMKL FFT is not allowed
    if len_axes > 3 or len(set(axes)) < len_axes:
        # transform chunk by chunk, each chunk holding at most three
        # distinct axes
        axes_chunk = _extract_axes_chunk(axes, chunk_size=3)
        for chunk in axes_chunk:
            a = _fft(
                a,
                norm=norm,
                out=out,
                forward=forward,
                in_place=in_place,
                # TODO: c2c=True is place holder for future development of
                # rfft2, irfft2, rfftn, irfftn
                c2c=True,
                axes=chunk,
            )
        return a

    if a.ndim == len_axes:
        # non-batch FFT
        axes = None

    return _fft(
        a,
        norm=norm,
        out=out,
        forward=forward,
        in_place=in_place,
        # TODO: c2c=True is place holder for future development of
        # rfft2, irfft2, rfftn, irfftn
        c2c=True,
        axes=axes,
    )
0 commit comments