@@ -242,61 +242,78 @@ def _copy_array(x, complex_input):
242
242
return x , copy_flag
243
243
244
244
245
- def _extract_axes_chunk (a , chunk_size = 3 ):
245
+ def _extract_axes_chunk (a , s , chunk_size = 3 ):
246
246
"""
247
- Classify input into a list of list with each list containing
248
- only unique values and its length is at most `chunk_size`.
247
+ Classify the first input into a list of lists with each list containing
248
+ only unique values in reverse order and its length is at most `chunk_size`.
249
+ The second input is also classified into a list of lists with each list
250
+ containing the corresponding values of the first input.
249
251
250
252
Parameters
251
253
----------
252
- a : list, tuple
253
- Input.
254
+ a : list or tuple of ints
255
+ The first input.
256
+ s : list or tuple of ints
257
+ The second input.
254
258
chunk_size : int
255
259
Maximum number of elements in each chunk.
256
260
257
261
Return
258
262
------
259
- out : list of lists
260
- List of lists with each list containing only unique values
261
- and its length is at most `chunk_size`.
262
- The final list is returned in reverse order.
263
+ out : a tuple of two lists
264
+ The first element of output is a list of lists with each list
265
+ containing only unique values in revere order and its length is
266
+ at most `chunk_size`.
267
+ The second element of output is a list of lists with each list
268
+ containing the corresponding values of the first input.
263
269
264
270
Examples
265
271
--------
266
272
>>> axes = (0, 1, 2, 3, 4)
267
- >>> _extract_axes_chunk(axes, chunk_size=3)
268
- [[2, 3, 4], [0, 1]]
273
+ >>> shape = (7, 8, 10, 9, 5)
274
+ >>> _extract_axes_chunk(axes, shape, chunk_size=3)
275
+ ([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])
269
276
270
- >>> axes = (0, 1, 2, 3, 4, 4)
271
- >>> _extract_axes_chunk(axes, chunk_size=3)
272
- [[4], [2, 3, 4], [0, 1]]
277
+ >>> axes = (1, 0, 3, 2, 4, 4)
278
+ >>> shape = (7, 8, 10, 5, 7, 6)
279
+ >>> _extract_axes_chunk(axes, shape, chunk_size=3)
280
+ ([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])
273
281
274
282
"""
275
283
276
- chunks = []
277
- current_chunk = []
284
+ a_chunks = []
285
+ a_current_chunk = []
278
286
seen_elements = set ()
279
287
280
- for elem in a :
281
- if elem in seen_elements :
288
+ s_chunks = []
289
+ s_current_chunk = []
290
+
291
+ for a_elem , s_elem in zip (a , s ):
292
+ if a_elem in seen_elements :
282
293
# If element is already seen, start a new chunk
283
- chunks .append (current_chunk )
284
- current_chunk = [elem ]
285
- seen_elements = {elem }
294
+ a_chunks .append (a_current_chunk [::- 1 ])
295
+ s_chunks .append (s_current_chunk [::- 1 ])
296
+ a_current_chunk = [a_elem ]
297
+ s_current_chunk = [s_elem ]
298
+ seen_elements = {a_elem }
286
299
else :
287
- current_chunk .append (elem )
288
- seen_elements .add (elem )
289
-
290
- if len (current_chunk ) == chunk_size :
291
- chunks .append (current_chunk )
292
- current_chunk = []
300
+ a_current_chunk .append (a_elem )
301
+ s_current_chunk .append (s_elem )
302
+ seen_elements .add (a_elem )
303
+
304
+ if len (a_current_chunk ) == chunk_size :
305
+ a_chunks .append (a_current_chunk [::- 1 ])
306
+ s_chunks .append (s_current_chunk [::- 1 ])
307
+ a_current_chunk = []
308
+ s_current_chunk = []
293
309
seen_elements = set ()
294
310
295
311
# Add the last chunk if it's not empty
296
- if current_chunk :
297
- chunks .append (current_chunk )
312
+ if a_current_chunk :
313
+ a_chunks .append (a_current_chunk [::- 1 ])
314
+ s_chunks .append (s_current_chunk [::- 1 ])
298
315
299
- return chunks [::- 1 ]
316
+ return a_chunks [:: - 1 ], s_chunks [::- 1 ]
300
317
301
318
302
319
def _fft (a , norm , out , forward , in_place , c2c , axes = None ):
@@ -392,7 +409,7 @@ def _truncate_or_pad(a, shape, axes):
392
409
return a
393
410
394
411
395
- def _validate_out_keyword (a , out , axis , c2r , r2c ):
412
+ def _validate_out_keyword (a , out , s , axes , c2r , r2c ):
396
413
"""Validate out keyword argument."""
397
414
if out is not None :
398
415
dpnp .check_supported_arrays_type (out )
@@ -404,16 +421,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
404
421
"Input and output allocation queues are not compatible"
405
422
)
406
423
407
- # validate out shape
408
- expected_shape = a .shape
424
+ # validate out shape against the final shape,
425
+ # intermediate shapes may vary
426
+ expected_shape = list (a .shape )
427
+ for s_i , axis in zip (s [::- 1 ], axes [::- 1 ]):
428
+ expected_shape [axis ] = s_i
409
429
if r2c :
410
- expected_shape = list (a .shape )
411
- expected_shape [axis ] = a .shape [axis ] // 2 + 1
412
- expected_shape = tuple (expected_shape )
413
- if out .shape != expected_shape :
430
+ expected_shape [axes [- 1 ]] = expected_shape [axes [- 1 ]] // 2 + 1
431
+
432
+ if out .shape != tuple (expected_shape ):
414
433
raise ValueError (
415
434
"output array has incorrect shape, expected "
416
- f"{ expected_shape } , got { out .shape } ."
435
+ f"{ tuple ( expected_shape ) } , got { out .shape } ."
417
436
)
418
437
419
438
# validate out data type
@@ -477,7 +496,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
477
496
478
497
_check_norm (norm )
479
498
a = _truncate_or_pad (a , n , axis )
480
- _validate_out_keyword (a , out , axis , c2r , r2c )
499
+ _validate_out_keyword (a , out , ( n ,), ( axis ,) , c2r , r2c )
481
500
# if input array is copied, in-place FFT can be used
482
501
a , in_place = _copy_array (a , c2c or c2r )
483
502
if not in_place and out is not None :
@@ -519,36 +538,40 @@ def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
519
538
520
539
_validate_s_axes (a , s , axes )
521
540
s , axes = _cook_nd_args (a , s , axes )
522
- a = _truncate_or_pad (a , s , axes )
523
- # TODO: None, False, False are place holder for future development of
541
+ # TODO: False and False are place holder for future development of
524
542
# rfft2, irfft2, rfftn, irfftn
525
- _validate_out_keyword (a , out , None , False , False )
543
+ _validate_out_keyword (a , out , s , axes , False , False )
526
544
# TODO: True is place holder for future development of
527
545
# rfft2, irfft2, rfftn, irfftn
528
546
a , in_place = _copy_array (a , True )
529
547
530
- if a .size == 0 :
531
- return dpnp .get_result_array (a , out = out , casting = "same_kind" )
532
-
533
548
len_axes = len (axes )
534
549
# OneMKL supports up to 3-dimensional FFT on GPU
535
550
# repeated axis in OneMKL FFT is not allowed
536
551
if len_axes > 3 or len (set (axes )) < len_axes :
537
- axes_chunk = _extract_axes_chunk (axes , chunk_size = 3 )
538
- for chunk in axes_chunk :
552
+ axes_chunk , shape_chunk = _extract_axes_chunk (axes , s , chunk_size = 3 )
553
+ for s_chunk , a_chunk in zip (shape_chunk , axes_chunk ):
554
+ a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
555
+ if out is not None and out .shape == a .shape :
556
+ tmp_out = out
557
+ else :
558
+ tmp_out = None
539
559
a = _fft (
540
560
a ,
541
561
norm = norm ,
542
- out = out ,
562
+ out = tmp_out ,
543
563
forward = forward ,
544
564
in_place = in_place ,
545
565
# TODO: c2c=True is place holder for future development of
546
566
# rfft2, irfft2, rfftn, irfftn
547
567
c2c = True ,
548
- axes = chunk ,
568
+ axes = a_chunk ,
549
569
)
550
570
return a
551
571
572
+ a = _truncate_or_pad (a , s , axes )
573
+ if a .size == 0 :
574
+ return dpnp .get_result_array (a , out = out , casting = "same_kind" )
552
575
if a .ndim == len_axes :
553
576
# non-batch FFT
554
577
axes = None
0 commit comments