Skip to content

Commit 6d0fd93

Browse files
committed
update when both s and axes are given
1 parent 2be4c03 commit 6d0fd93

File tree

3 files changed

+117
-57
lines changed

3 files changed

+117
-57
lines changed

dpnp/fft/dpnp_utils_fft.py

Lines changed: 72 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -242,61 +242,78 @@ def _copy_array(x, complex_input):
242242
return x, copy_flag
243243

244244

245-
def _extract_axes_chunk(a, chunk_size=3):
245+
def _extract_axes_chunk(a, s, chunk_size=3):
246246
"""
247-
Classify input into a list of list with each list containing
248-
only unique values and its length is at most `chunk_size`.
247+
Classify the first input into a list of lists with each list containing
248+
only unique values in reverse order and its length is at most `chunk_size`.
249+
The second input is also classified into a list of lists with each list
250+
containing the corresponding values of the first input.
249251
250252
Parameters
251253
----------
252-
a : list, tuple
253-
Input.
254+
a : list or tuple of ints
255+
The first input.
256+
s : list or tuple of ints
257+
The second input.
254258
chunk_size : int
255259
Maximum number of elements in each chunk.
256260
257261
Return
258262
------
259-
out : list of lists
260-
List of lists with each list containing only unique values
261-
and its length is at most `chunk_size`.
262-
The final list is returned in reverse order.
263+
out : a tuple of two lists
264+
The first element of output is a list of lists with each list
265+
containing only unique values in reverse order and its length is
266+
at most `chunk_size`.
267+
The second element of output is a list of lists with each list
268+
containing the corresponding values of the first input.
263269
264270
Examples
265271
--------
266272
>>> axes = (0, 1, 2, 3, 4)
267-
>>> _extract_axes_chunk(axes, chunk_size=3)
268-
[[2, 3, 4], [0, 1]]
273+
>>> shape = (7, 8, 10, 9, 5)
274+
>>> _extract_axes_chunk(axes, shape, chunk_size=3)
275+
([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])
269276
270-
>>> axes = (0, 1, 2, 3, 4, 4)
271-
>>> _extract_axes_chunk(axes, chunk_size=3)
272-
[[4], [2, 3, 4], [0, 1]]
277+
>>> axes = (1, 0, 3, 2, 4, 4)
278+
>>> shape = (7, 8, 10, 5, 7, 6)
279+
>>> _extract_axes_chunk(axes, shape, chunk_size=3)
280+
([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])
273281
274282
"""
275283

276-
chunks = []
277-
current_chunk = []
284+
a_chunks = []
285+
a_current_chunk = []
278286
seen_elements = set()
279287

280-
for elem in a:
281-
if elem in seen_elements:
288+
s_chunks = []
289+
s_current_chunk = []
290+
291+
for a_elem, s_elem in zip(a, s):
292+
if a_elem in seen_elements:
282293
# If element is already seen, start a new chunk
283-
chunks.append(current_chunk)
284-
current_chunk = [elem]
285-
seen_elements = {elem}
294+
a_chunks.append(a_current_chunk[::-1])
295+
s_chunks.append(s_current_chunk[::-1])
296+
a_current_chunk = [a_elem]
297+
s_current_chunk = [s_elem]
298+
seen_elements = {a_elem}
286299
else:
287-
current_chunk.append(elem)
288-
seen_elements.add(elem)
289-
290-
if len(current_chunk) == chunk_size:
291-
chunks.append(current_chunk)
292-
current_chunk = []
300+
a_current_chunk.append(a_elem)
301+
s_current_chunk.append(s_elem)
302+
seen_elements.add(a_elem)
303+
304+
if len(a_current_chunk) == chunk_size:
305+
a_chunks.append(a_current_chunk[::-1])
306+
s_chunks.append(s_current_chunk[::-1])
307+
a_current_chunk = []
308+
s_current_chunk = []
293309
seen_elements = set()
294310

295311
# Add the last chunk if it's not empty
296-
if current_chunk:
297-
chunks.append(current_chunk)
312+
if a_current_chunk:
313+
a_chunks.append(a_current_chunk[::-1])
314+
s_chunks.append(s_current_chunk[::-1])
298315

299-
return chunks[::-1]
316+
return a_chunks[::-1], s_chunks[::-1]
300317

301318

302319
def _fft(a, norm, out, forward, in_place, c2c, axes=None):
@@ -392,7 +409,7 @@ def _truncate_or_pad(a, shape, axes):
392409
return a
393410

394411

395-
def _validate_out_keyword(a, out, axis, c2r, r2c):
412+
def _validate_out_keyword(a, out, s, axes, c2r, r2c):
396413
"""Validate out keyword argument."""
397414
if out is not None:
398415
dpnp.check_supported_arrays_type(out)
@@ -404,16 +421,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
404421
"Input and output allocation queues are not compatible"
405422
)
406423

407-
# validate out shape
408-
expected_shape = a.shape
424+
# validate out shape against the final shape,
425+
# intermediate shapes may vary
426+
expected_shape = list(a.shape)
427+
for s_i, axis in zip(s[::-1], axes[::-1]):
428+
expected_shape[axis] = s_i
409429
if r2c:
410-
expected_shape = list(a.shape)
411-
expected_shape[axis] = a.shape[axis] // 2 + 1
412-
expected_shape = tuple(expected_shape)
413-
if out.shape != expected_shape:
430+
expected_shape[axes[-1]] = expected_shape[axes[-1]] // 2 + 1
431+
432+
if out.shape != tuple(expected_shape):
414433
raise ValueError(
415434
"output array has incorrect shape, expected "
416-
f"{expected_shape}, got {out.shape}."
435+
f"{tuple(expected_shape)}, got {out.shape}."
417436
)
418437

419438
# validate out data type
@@ -477,7 +496,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
477496

478497
_check_norm(norm)
479498
a = _truncate_or_pad(a, n, axis)
480-
_validate_out_keyword(a, out, axis, c2r, r2c)
499+
_validate_out_keyword(a, out, (n,), (axis,), c2r, r2c)
481500
# if input array is copied, in-place FFT can be used
482501
a, in_place = _copy_array(a, c2c or c2r)
483502
if not in_place and out is not None:
@@ -519,36 +538,40 @@ def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
519538

520539
_validate_s_axes(a, s, axes)
521540
s, axes = _cook_nd_args(a, s, axes)
522-
a = _truncate_or_pad(a, s, axes)
523-
# TODO: None, False, False are place holder for future development of
541+
# TODO: False and False are place holder for future development of
524542
# rfft2, irfft2, rfftn, irfftn
525-
_validate_out_keyword(a, out, None, False, False)
543+
_validate_out_keyword(a, out, s, axes, False, False)
526544
# TODO: True is place holder for future development of
527545
# rfft2, irfft2, rfftn, irfftn
528546
a, in_place = _copy_array(a, True)
529547

530-
if a.size == 0:
531-
return dpnp.get_result_array(a, out=out, casting="same_kind")
532-
533548
len_axes = len(axes)
534549
# OneMKL supports up to 3-dimensional FFT on GPU
535550
# repeated axis in OneMKL FFT is not allowed
536551
if len_axes > 3 or len(set(axes)) < len_axes:
537-
axes_chunk = _extract_axes_chunk(axes, chunk_size=3)
538-
for chunk in axes_chunk:
552+
axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
553+
for s_chunk, a_chunk in zip(shape_chunk, axes_chunk):
554+
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
555+
if out is not None and out.shape == a.shape:
556+
tmp_out = out
557+
else:
558+
tmp_out = None
539559
a = _fft(
540560
a,
541561
norm=norm,
542-
out=out,
562+
out=tmp_out,
543563
forward=forward,
544564
in_place=in_place,
545565
# TODO: c2c=True is place holder for future development of
546566
# rfft2, irfft2, rfftn, irfftn
547567
c2c=True,
548-
axes=chunk,
568+
axes=a_chunk,
549569
)
550570
return a
551571

572+
a = _truncate_or_pad(a, s, axes)
573+
if a.size == 0:
574+
return dpnp.get_result_array(a, out=out, casting="same_kind")
552575
if a.ndim == len_axes:
553576
# non-batch FFT
554577
axes = None

tests/test_fft.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def test_fftn_repeated_axes(self, axes):
486486
assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True)
487487

488488
@pytest.mark.parametrize("axes", [(2, 3, 3, 2), (0, 0, 3, 3)])
489-
@pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 7)])
489+
@pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 9)])
490490
def test_fftn_repeated_axes_with_s(self, axes, s):
491491
x1 = numpy.random.uniform(-10, 10, 120)
492492
x2 = numpy.random.uniform(-10, 10, 120)
@@ -495,19 +495,56 @@ def test_fftn_repeated_axes_with_s(self, axes, s):
495495
)
496496
a = dpnp.asarray(a_np)
497497

498-
result = dpnp.fft.fftn(a, axes=axes)
498+
result = dpnp.fft.fftn(a, s=s, axes=axes)
499499
# Intel® NumPy ignores repeated axes, handle it one by one
500500
expected = a_np
501-
for ii in axes:
502-
expected = numpy.fft.fft(expected, axis=ii)
501+
for jj, ii in zip(s[::-1], axes[::-1]):
502+
expected = numpy.fft.fft(expected, n=jj, axis=ii)
503503
assert_dtype_allclose(result, expected, check_only_type_kind=True)
504504

505-
iresult = dpnp.fft.ifftn(result, axes=axes)
505+
iresult = dpnp.fft.ifftn(result, s=s, axes=axes)
506506
iexpected = expected
507-
for ii in axes:
508-
iexpected = numpy.fft.ifft(iexpected, axis=ii)
507+
for jj, ii in zip(s[::-1], axes[::-1]):
508+
iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii)
509509
assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True)
510510

511+
@pytest.mark.parametrize("axes", [(0, 1, 2, 3), (1, 2, 1, 2), (2, 2, 2, 3)])
512+
@pytest.mark.parametrize("s", [(2, 3, 4, 5), (5, 4, 7, 8), (2, 5, 1, 2)])
513+
def test_fftn_out(self, axes, s):
514+
x1 = numpy.random.uniform(-10, 10, 120)
515+
x2 = numpy.random.uniform(-10, 10, 120)
516+
a_np = numpy.array(x1 + 1j * x2, dtype=numpy.complex64).reshape(
517+
2, 3, 4, 5
518+
)
519+
a = dpnp.asarray(a_np)
520+
521+
out_shape = list(a.shape)
522+
for s_i, axis in zip(s[::-1], axes[::-1]):
523+
out_shape[axis] = s_i
524+
result = dpnp.empty(out_shape, dtype=a.dtype)
525+
dpnp.fft.fftn(a, out=result, s=s, axes=axes)
526+
# Intel® NumPy ignores repeated axes, handle it one by one
527+
expected = a_np
528+
for jj, ii in zip(s[::-1], axes[::-1]):
529+
expected = numpy.fft.fft(expected, n=jj, axis=ii)
530+
assert_dtype_allclose(result, expected, check_only_type_kind=True)
531+
532+
iresult = dpnp.empty(out_shape, dtype=a.dtype)
533+
dpnp.fft.ifftn(result, out=iresult, s=s, axes=axes)
534+
iexpected = expected
535+
for jj, ii in zip(s[::-1], axes[::-1]):
536+
iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii)
537+
assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True)
538+
539+
def test_negative_s(self):
540+
# In stock NumPy 2.0, if an element of s is -1, the whole input along the corresponding axis is used (no padding/trimming).
541+
a_np = numpy.empty((3, 4, 5), dtype=numpy.complex64)
542+
a = dpnp.array(a_np)
543+
544+
result = dpnp.fft.fftn(a, s=(-1, -1), axes=(0, 2))
545+
expected = numpy.fft.fftn(a_np, s=(3, 5), axes=(0, 2))
546+
assert_dtype_allclose(result, expected, check_only_type_kind=True)
547+
511548
def test_fftn_empty_array(self):
512549
a_np = numpy.empty((10, 0, 4), dtype=numpy.complex64)
513550
a = dpnp.array(a_np)

tests/third_party/cupy/fft_tests/test_fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_ifft2(self, xp, dtype, order):
216216
{"shape": (3, 4), "s": (1, 5), "axes": (0, 1)},
217217
{"shape": (3, 4), "s": None, "axes": (-2, -1)},
218218
{"shape": (3, 4), "s": None, "axes": (-1, -2)},
219-
{"shape": (3, 4), "s": None, "axes": [-1, -2]},
219+
{"shape": (3, 4), "s": None, "axes": (-1, -2)},
220220
# {"shape": (3, 4), "s": None, "axes": (0,)}, # mkl_fft gh-109
221221
# {"shape": (3, 4), "s": None, "axes": ()}, # mkl_fft gh-108
222222
{"shape": (3, 4), "s": None, "axes": None},

0 commit comments

Comments
 (0)