Skip to content

Commit 3b4b22f

Browse files
authored
Merge e6f75cb into 4550d18
2 parents 4550d18 + e6f75cb commit 3b4b22f

File tree

9 files changed

+1006
-261
lines changed

9 files changed

+1006
-261
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ env:
5252
test_umath.py
5353
test_usm_type.py
5454
third_party/cupy/core_tests
55+
third_party/cupy/fft_tests
5556
third_party/cupy/creation_tests
5657
third_party/cupy/indexing_tests/test_indexing.py
5758
third_party/cupy/lib_tests

dpnp/fft/dpnp_iface_fft.py

Lines changed: 381 additions & 137 deletions
Large diffs are not rendered by default.

dpnp/fft/dpnp_utils_fft.py

Lines changed: 220 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,16 @@
3636
# pylint: disable=protected-access
3737
# pylint: disable=no-name-in-module
3838

39+
from collections.abc import Sequence
40+
3941
import dpctl
4042
import dpctl.tensor._tensor_impl as ti
4143
import dpctl.utils as dpu
4244
import numpy
43-
from dpctl.tensor._numpy_helper import normalize_axis_index
45+
from dpctl.tensor._numpy_helper import (
46+
normalize_axis_index,
47+
normalize_axis_tuple,
48+
)
4449
from dpctl.utils import ExecutionPlacementError
4550

4651
import dpnp
@@ -54,6 +59,7 @@
5459

5560
__all__ = [
5661
"dpnp_fft",
62+
"dpnp_fftn",
5763
]
5864

5965

@@ -159,6 +165,37 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides):
159165
return result
160166

161167

168+
# TODO: c2r keyword is a placeholder for irfftn
169+
def _cook_nd_args(a, s=None, axes=None, c2r=False):
    """
    Resolve the `s` and `axes` arguments of an N-D FFT into explicit lists.

    Parameters
    ----------
    a : array
        Input array; only ``a.shape`` is read.
    s : sequence of ints, optional
        Requested length along each transformed axis. An entry of ``-1``
        means "use the whole input along the matching axis". When ``None``,
        the lengths are taken from ``a.shape`` (all axes if `axes` is also
        ``None``, otherwise only the axes listed in `axes`).
    axes : sequence of ints, optional
        Axes to transform. When ``None``, the last ``len(s)`` axes are used.
    c2r : bool, optional
        When ``True`` and `s` was not given, the last length is replaced by
        ``2 * (m - 1)``, where ``m`` is the input length along ``axes[-1]``.

    Returns
    -------
    out : tuple
        ``(s, axes)``: `s` is a list with one length per entry of `axes`;
        `axes` is returned as passed in, or as a list of negative indices
        when it was ``None``.

    Raises
    ------
    ValueError
        If an entry of `s` is smaller than 1 and is not ``-1``, or if `s`
        and `axes` have different lengths.
    """
    shapeless = s is None
    if shapeless:
        if axes is None:
            s = list(a.shape)
        else:
            # plain indexing keeps the lengths as Python ints and also
            # works when `axes` is empty
            s = [a.shape[axis] for axis in axes]
    else:
        # materialize once, so `s` can be both iterated and measured
        s = list(s)

    # `None` entries are skipped here and are returned unchanged
    for s_i in s:
        if s_i is not None and s_i < 1 and s_i != -1:
            raise ValueError(
                f"Invalid number of FFT data points ({s_i}) specified."
            )

    if axes is None:
        axes = list(range(-len(s), 0))

    if len(s) != len(axes):
        raise ValueError("Shape and axes have different lengths.")

    if c2r and shapeless:
        s[-1] = (a.shape[axes[-1]] - 1) * 2

    # use the whole input array along axis `i` if `s[i] == -1`
    s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
    return s, axes
197+
198+
162199
def _copy_array(x, complex_input):
163200
"""
164201
Creating a C-contiguous copy of input array if input array has a negative
@@ -204,6 +241,80 @@ def _copy_array(x, complex_input):
204241
return x, copy_flag
205242

206243

244+
def _extract_axes_chunk(a, s, chunk_size=3):
    """
    Split `a` into chunks of unique values and split `s` the same way.

    The first input is split into a list of lists; each inner list holds
    only unique values, has at most `chunk_size` elements and is stored in
    reverse order. A new chunk is started whenever the current chunk is
    full or the next value is already present in it. The second input is
    split into a list of lists holding the corresponding values. The order
    of the chunks themselves is reversed as well.

    Parameters
    ----------
    a : list or tuple of ints
        The first input.
    s : list or tuple of ints
        The second input, paired element-wise with `a`.
    chunk_size : int
        Maximum number of elements in each chunk; must be at least 1.

    Returns
    -------
    out : a tuple of two lists
        The first element of output is a list of lists with each list
        containing only unique values in reverse order and its length is
        at most `chunk_size`.
        The second element of output is a list of lists with each list
        containing the corresponding values of the second input.

    Raises
    ------
    ValueError
        If `chunk_size` is smaller than 1.

    Examples
    --------
    >>> axes = (0, 1, 2, 3, 4)
    >>> shape = (7, 8, 10, 9, 5)
    >>> _extract_axes_chunk(axes, shape, chunk_size=3)
    ([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])

    >>> axes = (1, 0, 3, 2, 4, 4)
    >>> shape = (7, 8, 10, 5, 7, 6)
    >>> _extract_axes_chunk(axes, shape, chunk_size=3)
    ([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])

    """

    # a non-positive size can never be reached by a growing chunk, which
    # would silently produce chunks of unbounded length
    if chunk_size < 1:
        raise ValueError(
            f"`chunk_size` must be a positive integer, got {chunk_size}."
        )

    a_chunks = []
    s_chunks = []
    a_current_chunk = []
    s_current_chunk = []
    seen_elements = set()

    for a_elem, s_elem in zip(a, s):
        if a_elem in seen_elements:
            # a repeated value closes the current chunk and opens a new one
            a_chunks.append(a_current_chunk[::-1])
            s_chunks.append(s_current_chunk[::-1])
            a_current_chunk = []
            s_current_chunk = []
            seen_elements = set()

        a_current_chunk.append(a_elem)
        s_current_chunk.append(s_elem)
        seen_elements.add(a_elem)

        if len(a_current_chunk) == chunk_size:
            # the chunk is full
            a_chunks.append(a_current_chunk[::-1])
            s_chunks.append(s_current_chunk[::-1])
            a_current_chunk = []
            s_current_chunk = []
            seen_elements = set()

    # add the last chunk if it is not empty
    if a_current_chunk:
        a_chunks.append(a_current_chunk[::-1])
        s_chunks.append(s_current_chunk[::-1])

    return a_chunks[::-1], s_chunks[::-1]
316+
317+
207318
def _fft(a, norm, out, forward, in_place, c2c, axes=None):
208319
"""Calculates FFT of the input array along the specified axes."""
209320

@@ -238,7 +349,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):
238349

239350
def _scale_result(res, a_shape, norm, forward, index):
240351
"""Scale the result of the FFT according to `norm`."""
241-
scale = numpy.prod(a_shape[index:], dtype=res.real.dtype)
352+
if res.dtype in [dpnp.float32, dpnp.complex64]:
353+
dtype = dpnp.float32
354+
else:
355+
dtype = dpnp.float64
356+
scale = numpy.prod(a_shape[index:], dtype=dtype)
242357
norm_factor = 1
243358
if norm == "ortho":
244359
norm_factor = numpy.sqrt(scale)
@@ -293,7 +408,7 @@ def _truncate_or_pad(a, shape, axes):
293408
return a
294409

295410

296-
def _validate_out_keyword(a, out, axis, c2r, r2c):
411+
def _validate_out_keyword(a, out, s, axes, c2r, r2c):
297412
"""Validate out keyword argument."""
298413
if out is not None:
299414
dpnp.check_supported_arrays_type(out)
@@ -305,16 +420,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
305420
"Input and output allocation queues are not compatible"
306421
)
307422

308-
# validate out shape
309-
expected_shape = a.shape
423+
# validate out shape against the final shape,
424+
# intermediate shapes may vary
425+
expected_shape = list(a.shape)
426+
for s_i, axis in zip(s[::-1], axes[::-1]):
427+
expected_shape[axis] = s_i
310428
if r2c:
311-
expected_shape = list(a.shape)
312-
expected_shape[axis] = a.shape[axis] // 2 + 1
313-
expected_shape = tuple(expected_shape)
314-
if out.shape != expected_shape:
429+
expected_shape[axes[-1]] = expected_shape[axes[-1]] // 2 + 1
430+
431+
if out.shape != tuple(expected_shape):
315432
raise ValueError(
316433
"output array has incorrect shape, expected "
317-
f"{expected_shape}, got {out.shape}."
434+
f"{tuple(expected_shape)}, got {out.shape}."
318435
)
319436

320437
# validate out data type
@@ -328,9 +445,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
328445
raise TypeError("output array should have complex data type.")
329446

330447

448+
def _validate_s_axes(a, s, axes):
    """
    Validate the `s` and `axes` arguments of an N-D FFT.

    Parameters
    ----------
    a : array
        Input array; only ``a.ndim`` is read, and only when `axes` is given.
    s : sequence of ints or None
        Requested lengths along the transformed axes.
    axes : sequence of ints or None
        Axes to transform.

    Raises
    ------
    TypeError
        If `s` is neither ``None`` nor a sequence whose items are all
        instances of ``int``.
    ValueError
        If `s` is given while `axes` is ``None``.
    """
    if axes is not None:
        # validate axes is a sequence and
        # each axis is an integer within the range;
        # duplicates are dropped first and the normalized result is unused
        normalize_axis_tuple(list(set(axes)), a.ndim, "axes")

    if s is None:
        return

    is_int_sequence = isinstance(s, Sequence) and all(
        isinstance(s_i, int) for s_i in s
    )
    if not is_int_sequence:
        raise TypeError("`s` must be `None` or a sequence of integers.")

    if axes is None:
        raise ValueError("`axes` should not be `None` if `s` is not `None`.")
469+
470+
331471
def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
332472
"""Calculates 1-D FFT of the input array along axis"""
333473

474+
_check_norm(norm)
334475
a_ndim = a.ndim
335476
if a_ndim == 0:
336477
raise ValueError("Input array must be at least 1D")
@@ -354,7 +495,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
354495

355496
_check_norm(norm)
356497
a = _truncate_or_pad(a, n, axis)
357-
_validate_out_keyword(a, out, axis, c2r, r2c)
498+
_validate_out_keyword(a, out, (n,), (axis,), c2r, r2c)
358499
# if input array is copied, in-place FFT can be used
359500
a, in_place = _copy_array(a, c2c or c2r)
360501
if not in_place and out is not None:
@@ -377,3 +518,71 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
377518
c2c=c2c,
378519
axes=axis,
379520
)
521+
522+
523+
def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
    """
    Calculates N-D FFT of the input array along axes.

    The input is returned unchanged when `axes` is an empty list or tuple,
    or when the input is 0-dimensional and `axes` is ``None``. A
    0-dimensional input with `axes` given raises ``IndexError``.
    """

    _check_norm(norm)

    # an explicitly empty list/tuple of axes leaves nothing to transform
    if isinstance(axes, (list, tuple)) and not axes:
        return a

    if a.ndim == 0:
        if axes is None:
            return a
        raise IndexError(
            "Input array is 0-dimensional while axis is not `None`."
        )

    _validate_s_axes(a, s, axes)
    s, axes = _cook_nd_args(a, s, axes)
    # TODO: False and False are place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    _validate_out_keyword(a, out, s, axes, False, False)
    # TODO: True is place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    a, in_place = _copy_array(a, True)

    n_axes = len(axes)
    # OneMKL supports up to 3-dimensional FFT on GPU
    # repeated axis in OneMKL FFT is not allowed
    needs_chunking = n_axes > 3 or len(set(axes)) < n_axes

    if not needs_chunking:
        a = _truncate_or_pad(a, s, axes)
        if a.size == 0:
            return dpnp.get_result_array(a, out=out, casting="same_kind")

        return _fft(
            a,
            norm=norm,
            out=out,
            forward=forward,
            in_place=in_place,
            # TODO: c2c=True is place holder for future development of
            # rfft2, irfft2, rfftn, irfftn
            c2c=True,
            # non-batch FFT when every axis of the array is transformed
            axes=None if a.ndim == n_axes else axes,
        )

    # transform chunk by chunk, feeding each result into the next step
    chunked_axes, chunked_shape = _extract_axes_chunk(axes, s, chunk_size=3)
    for shape_part, axes_part in zip(chunked_shape, chunked_axes):
        a = _truncate_or_pad(a, shape=shape_part, axes=axes_part)
        # `out` is only passed on when it matches the current shape
        shape_matches_out = out is not None and out.shape == a.shape
        a = _fft(
            a,
            norm=norm,
            out=out if shape_matches_out else None,
            forward=forward,
            in_place=in_place,
            # TODO: c2c=True is place holder for future development of
            # rfft2, irfft2, rfftn, irfftn
            c2c=True,
            axes=axes_part,
        )
    return a

tests/skipped_tests.tbl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp
99
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
1010
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]
1111

12-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
13-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
14-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
15-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
16-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
17-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
18-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
19-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
20-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
21-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
22-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2
23-
24-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
25-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
26-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
27-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
28-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
29-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
30-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
31-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
32-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
33-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
34-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn
35-
3612
tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory
3713

3814
tests/test_umath.py::test_umaths[('divmod', 'ii')]

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -110,30 +110,6 @@ tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_pa
110110
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_max
111111
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_min
112112

113-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
114-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
115-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
116-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
117-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
118-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
119-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
120-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
121-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
122-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
123-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2
124-
125-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
126-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
127-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
128-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
129-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
130-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
131-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
132-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
133-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
134-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
135-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn
136-
137113
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
138114
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
139115
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1

0 commit comments

Comments
 (0)