36
36
# pylint: disable=protected-access
37
37
# pylint: disable=no-name-in-module
38
38
39
+ from collections .abc import Sequence
40
+
39
41
import dpctl
40
42
import dpctl .tensor ._tensor_impl as ti
41
43
import dpctl .utils as dpu
42
44
import numpy
43
- from dpctl .tensor ._numpy_helper import normalize_axis_index
45
+ from dpctl .tensor ._numpy_helper import (
46
+ normalize_axis_index ,
47
+ normalize_axis_tuple ,
48
+ )
44
49
from dpctl .utils import ExecutionPlacementError
45
50
46
51
import dpnp
54
59
55
60
__all__ = [
56
61
"dpnp_fft" ,
62
+ "dpnp_fftn" ,
57
63
]
58
64
59
65
@@ -159,6 +165,37 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides):
159
165
return result
160
166
161
167
168
+ # TODO: c2r keyword is place holder for irfftn
169
+ def _cook_nd_args (a , s = None , axes = None , c2r = False ):
170
+ if s is None :
171
+ shapeless = True
172
+ if axes is None :
173
+ s = list (a .shape )
174
+ else :
175
+ s = numpy .take (a .shape , axes )
176
+ else :
177
+ shapeless = False
178
+
179
+ for s_i in s :
180
+ if s_i is not None and s_i < 1 and s_i != - 1 :
181
+ raise ValueError (
182
+ f"Invalid number of FFT data points ({ s_i } ) specified."
183
+ )
184
+
185
+ if axes is None :
186
+ axes = list (range (- len (s ), 0 ))
187
+
188
+ if len (s ) != len (axes ):
189
+ raise ValueError ("Shape and axes have different lengths." )
190
+
191
+ s = list (s )
192
+ if c2r and shapeless :
193
+ s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
194
+ # use the whole input array along axis `i` if `s[i] == -1`
195
+ s = [a .shape [_a ] if _s == - 1 else _s for _s , _a in zip (s , axes )]
196
+ return s , axes
197
+
198
+
162
199
def _copy_array (x , complex_input ):
163
200
"""
164
201
Creating a C-contiguous copy of input array if input array has a negative
@@ -204,6 +241,80 @@ def _copy_array(x, complex_input):
204
241
return x , copy_flag
205
242
206
243
244
+ def _extract_axes_chunk (a , s , chunk_size = 3 ):
245
+ """
246
+ Classify the first input into a list of lists with each list containing
247
+ only unique values in reverse order and its length is at most `chunk_size`.
248
+ The second input is also classified into a list of lists with each list
249
+ containing the corresponding values of the first input.
250
+
251
+ Parameters
252
+ ----------
253
+ a : list or tuple of ints
254
+ The first input.
255
+ s : list or tuple of ints
256
+ The second input.
257
+ chunk_size : int
258
+ Maximum number of elements in each chunk.
259
+
260
+ Return
261
+ ------
262
+ out : a tuple of two lists
263
+ The first element of output is a list of lists with each list
264
+ containing only unique values in revere order and its length is
265
+ at most `chunk_size`.
266
+ The second element of output is a list of lists with each list
267
+ containing the corresponding values of the first input.
268
+
269
+ Examples
270
+ --------
271
+ >>> axes = (0, 1, 2, 3, 4)
272
+ >>> shape = (7, 8, 10, 9, 5)
273
+ >>> _extract_axes_chunk(axes, shape, chunk_size=3)
274
+ ([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])
275
+
276
+ >>> axes = (1, 0, 3, 2, 4, 4)
277
+ >>> shape = (7, 8, 10, 5, 7, 6)
278
+ >>> _extract_axes_chunk(axes, shape, chunk_size=3)
279
+ ([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])
280
+
281
+ """
282
+
283
+ a_chunks = []
284
+ a_current_chunk = []
285
+ seen_elements = set ()
286
+
287
+ s_chunks = []
288
+ s_current_chunk = []
289
+
290
+ for a_elem , s_elem in zip (a , s ):
291
+ if a_elem in seen_elements :
292
+ # If element is already seen, start a new chunk
293
+ a_chunks .append (a_current_chunk [::- 1 ])
294
+ s_chunks .append (s_current_chunk [::- 1 ])
295
+ a_current_chunk = [a_elem ]
296
+ s_current_chunk = [s_elem ]
297
+ seen_elements = {a_elem }
298
+ else :
299
+ a_current_chunk .append (a_elem )
300
+ s_current_chunk .append (s_elem )
301
+ seen_elements .add (a_elem )
302
+
303
+ if len (a_current_chunk ) == chunk_size :
304
+ a_chunks .append (a_current_chunk [::- 1 ])
305
+ s_chunks .append (s_current_chunk [::- 1 ])
306
+ a_current_chunk = []
307
+ s_current_chunk = []
308
+ seen_elements = set ()
309
+
310
+ # Add the last chunk if it's not empty
311
+ if a_current_chunk :
312
+ a_chunks .append (a_current_chunk [::- 1 ])
313
+ s_chunks .append (s_current_chunk [::- 1 ])
314
+
315
+ return a_chunks [::- 1 ], s_chunks [::- 1 ]
316
+
317
+
207
318
def _fft (a , norm , out , forward , in_place , c2c , axes = None ):
208
319
"""Calculates FFT of the input array along the specified axes."""
209
320
@@ -238,7 +349,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):
238
349
239
350
def _scale_result (res , a_shape , norm , forward , index ):
240
351
"""Scale the result of the FFT according to `norm`."""
241
- scale = numpy .prod (a_shape [index :], dtype = res .real .dtype )
352
+ if res .dtype in [dpnp .float32 , dpnp .complex64 ]:
353
+ dtype = dpnp .float32
354
+ else :
355
+ dtype = dpnp .float64
356
+ scale = numpy .prod (a_shape [index :], dtype = dtype )
242
357
norm_factor = 1
243
358
if norm == "ortho" :
244
359
norm_factor = numpy .sqrt (scale )
@@ -293,7 +408,7 @@ def _truncate_or_pad(a, shape, axes):
293
408
return a
294
409
295
410
296
- def _validate_out_keyword (a , out , axis , c2r , r2c ):
411
+ def _validate_out_keyword (a , out , s , axes , c2r , r2c ):
297
412
"""Validate out keyword argument."""
298
413
if out is not None :
299
414
dpnp .check_supported_arrays_type (out )
@@ -305,16 +420,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
305
420
"Input and output allocation queues are not compatible"
306
421
)
307
422
308
- # validate out shape
309
- expected_shape = a .shape
423
+ # validate out shape against the final shape,
424
+ # intermediate shapes may vary
425
+ expected_shape = list (a .shape )
426
+ for s_i , axis in zip (s [::- 1 ], axes [::- 1 ]):
427
+ expected_shape [axis ] = s_i
310
428
if r2c :
311
- expected_shape = list (a .shape )
312
- expected_shape [axis ] = a .shape [axis ] // 2 + 1
313
- expected_shape = tuple (expected_shape )
314
- if out .shape != expected_shape :
429
+ expected_shape [axes [- 1 ]] = expected_shape [axes [- 1 ]] // 2 + 1
430
+
431
+ if out .shape != tuple (expected_shape ):
315
432
raise ValueError (
316
433
"output array has incorrect shape, expected "
317
- f"{ expected_shape } , got { out .shape } ."
434
+ f"{ tuple ( expected_shape ) } , got { out .shape } ."
318
435
)
319
436
320
437
# validate out data type
@@ -328,9 +445,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
328
445
raise TypeError ("output array should have complex data type." )
329
446
330
447
448
+ def _validate_s_axes (a , s , axes ):
449
+ if axes is not None :
450
+ # validate axes is a sequence and
451
+ # each axis is an integer within the range
452
+ normalize_axis_tuple (list (set (axes )), a .ndim , "axes" )
453
+
454
+ if s is not None :
455
+ raise_error = False
456
+ if isinstance (s , Sequence ):
457
+ if any (not isinstance (s_i , int ) for s_i in s ):
458
+ raise_error = True
459
+ else :
460
+ raise_error = True
461
+
462
+ if raise_error :
463
+ raise TypeError ("`s` must be `None` or a sequence of integers." )
464
+
465
+ if axes is None :
466
+ raise ValueError (
467
+ "`axes` should not be `None` if `s` is not `None`."
468
+ )
469
+
470
+
331
471
def dpnp_fft (a , forward , real , n = None , axis = - 1 , norm = None , out = None ):
332
472
"""Calculates 1-D FFT of the input array along axis"""
333
473
474
+ _check_norm (norm )
334
475
a_ndim = a .ndim
335
476
if a_ndim == 0 :
336
477
raise ValueError ("Input array must be at least 1D" )
@@ -354,7 +495,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
354
495
355
496
_check_norm (norm )
356
497
a = _truncate_or_pad (a , n , axis )
357
- _validate_out_keyword (a , out , axis , c2r , r2c )
498
+ _validate_out_keyword (a , out , ( n ,), ( axis ,) , c2r , r2c )
358
499
# if input array is copied, in-place FFT can be used
359
500
a , in_place = _copy_array (a , c2c or c2r )
360
501
if not in_place and out is not None :
@@ -377,3 +518,71 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
377
518
c2c = c2c ,
378
519
axes = axis ,
379
520
)
521
+
522
+
523
+ def dpnp_fftn (a , forward , s = None , axes = None , norm = None , out = None ):
524
+ """Calculates N-D FFT of the input array along axes"""
525
+
526
+ _check_norm (norm )
527
+ if isinstance (axes , (list , tuple )) and len (axes ) == 0 :
528
+ return a
529
+
530
+ if a .ndim == 0 :
531
+ if axes is not None :
532
+ raise IndexError (
533
+ "Input array is 0-dimensional while axis is not `None`."
534
+ )
535
+
536
+ return a
537
+
538
+ _validate_s_axes (a , s , axes )
539
+ s , axes = _cook_nd_args (a , s , axes )
540
+ # TODO: False and False are place holder for future development of
541
+ # rfft2, irfft2, rfftn, irfftn
542
+ _validate_out_keyword (a , out , s , axes , False , False )
543
+ # TODO: True is place holder for future development of
544
+ # rfft2, irfft2, rfftn, irfftn
545
+ a , in_place = _copy_array (a , True )
546
+
547
+ len_axes = len (axes )
548
+ # OneMKL supports up to 3-dimensional FFT on GPU
549
+ # repeated axis in OneMKL FFT is not allowed
550
+ if len_axes > 3 or len (set (axes )) < len_axes :
551
+ axes_chunk , shape_chunk = _extract_axes_chunk (axes , s , chunk_size = 3 )
552
+ for s_chunk , a_chunk in zip (shape_chunk , axes_chunk ):
553
+ a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
554
+ if out is not None and out .shape == a .shape :
555
+ tmp_out = out
556
+ else :
557
+ tmp_out = None
558
+ a = _fft (
559
+ a ,
560
+ norm = norm ,
561
+ out = tmp_out ,
562
+ forward = forward ,
563
+ in_place = in_place ,
564
+ # TODO: c2c=True is place holder for future development of
565
+ # rfft2, irfft2, rfftn, irfftn
566
+ c2c = True ,
567
+ axes = a_chunk ,
568
+ )
569
+ return a
570
+
571
+ a = _truncate_or_pad (a , s , axes )
572
+ if a .size == 0 :
573
+ return dpnp .get_result_array (a , out = out , casting = "same_kind" )
574
+ if a .ndim == len_axes :
575
+ # non-batch FFT
576
+ axes = None
577
+
578
+ return _fft (
579
+ a ,
580
+ norm = norm ,
581
+ out = out ,
582
+ forward = forward ,
583
+ in_place = in_place ,
584
+ # TODO: c2c=True is place holder for future development of
585
+ # rfft2, irfft2, rfftn, irfftn
586
+ c2c = True ,
587
+ axes = axes ,
588
+ )
0 commit comments