Skip to content

Commit 0976da8

Browse files
authored
Merge fc76991 into d353299
2 parents d353299 + fc76991 commit 0976da8

File tree

9 files changed

+937
-252
lines changed

9 files changed

+937
-252
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ env:
5252
test_umath.py
5353
test_usm_type.py
5454
third_party/cupy/core_tests
55+
third_party/cupy/fft_tests
5556
third_party/cupy/creation_tests
5657
third_party/cupy/indexing_tests/test_indexing.py
5758
third_party/cupy/lib_tests

dpnp/fft/dpnp_iface_fft.py

Lines changed: 381 additions & 137 deletions
Large diffs are not rendered by default.

dpnp/fft/dpnp_utils_fft.py

Lines changed: 188 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,16 @@
3737
# pylint: disable=c-extension-no-member
3838
# pylint: disable=no-name-in-module
3939

40+
from collections.abc import Sequence
41+
4042
import dpctl
4143
import dpctl.tensor._tensor_impl as ti
4244
import dpctl.utils as dpu
4345
import numpy
44-
from dpctl.tensor._numpy_helper import normalize_axis_index
46+
from dpctl.tensor._numpy_helper import (
47+
normalize_axis_index,
48+
normalize_axis_tuple,
49+
)
4550
from dpctl.utils import ExecutionPlacementError
4651

4752
import dpnp
@@ -55,6 +60,7 @@
5560

5661
__all__ = [
5762
"dpnp_fft",
63+
"dpnp_fftn",
5864
]
5965

6066

@@ -66,6 +72,37 @@ def _check_norm(norm):
6672
)
6773

6874

75+
# TODO: c2r keyword is place holder for irfftn
76+
def _cook_nd_args(a, s=None, axes=None, c2r=False):
77+
if s is None:
78+
shapeless = True
79+
if axes is None:
80+
s = list(a.shape)
81+
else:
82+
s = numpy.take(a.shape, axes)
83+
else:
84+
shapeless = False
85+
86+
for s_i in s:
87+
if s_i is not None and s_i < 1 and s_i != -1:
88+
raise ValueError(
89+
f"Invalid number of FFT data points ({s_i}) specified."
90+
)
91+
92+
if axes is None:
93+
axes = list(range(-len(s), 0))
94+
95+
if len(s) != len(axes):
96+
raise ValueError("Shape and axes have different lengths.")
97+
98+
s = list(s)
99+
if c2r and shapeless:
100+
s[-1] = (a.shape[axes[-1]] - 1) * 2
101+
# use the whole input array along axis `i` if `s[i] == -1`
102+
s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
103+
return s, axes
104+
105+
69106
def _commit_descriptor(a, in_place, c2c, a_strides, index, axes):
70107
"""Commit the FFT descriptor for the input array."""
71108

@@ -205,6 +242,63 @@ def _copy_array(x, complex_input):
205242
return x, copy_flag
206243

207244

245+
def _extract_axes_chunk(a, chunk_size=3):
246+
"""
247+
Classify input into a list of list with each list containing
248+
only unique values and its length is at most `chunk_size`.
249+
250+
Parameters
251+
----------
252+
a : list, tuple
253+
Input.
254+
chunk_size : int
255+
Maximum number of elements in each chunk.
256+
257+
Return
258+
------
259+
out : list of lists
260+
List of lists with each list containing only unique values
261+
and its length is at most `chunk_size`.
262+
The final list is returned in reverse order.
263+
264+
Examples
265+
--------
266+
>>> axes = (0, 1, 2, 3, 4)
267+
>>> _extract_axes_chunk(axes, chunk_size=3)
268+
[[2, 3, 4], [0, 1]]
269+
270+
>>> axes = (0, 1, 2, 3, 4, 4)
271+
>>> _extract_axes_chunk(axes, chunk_size=3)
272+
[[4], [2, 3, 4], [0, 1]]
273+
274+
"""
275+
276+
chunks = []
277+
current_chunk = []
278+
seen_elements = set()
279+
280+
for elem in a:
281+
if elem in seen_elements:
282+
# If element is already seen, start a new chunk
283+
chunks.append(current_chunk)
284+
current_chunk = [elem]
285+
seen_elements = {elem}
286+
else:
287+
current_chunk.append(elem)
288+
seen_elements.add(elem)
289+
290+
if len(current_chunk) == chunk_size:
291+
chunks.append(current_chunk)
292+
current_chunk = []
293+
seen_elements = set()
294+
295+
# Add the last chunk if it's not empty
296+
if current_chunk:
297+
chunks.append(current_chunk)
298+
299+
return chunks[::-1]
300+
301+
208302
def _fft(a, norm, out, forward, in_place, c2c, axes=None):
209303
"""Calculates FFT of the input array along the specified axes."""
210304

@@ -239,7 +333,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):
239333

240334
def _scale_result(res, a_shape, norm, forward, index):
241335
"""Scale the result of the FFT according to `norm`."""
242-
scale = numpy.prod(a_shape[index:], dtype=res.real.dtype)
336+
if res.dtype in [dpnp.float32, dpnp.complex64]:
337+
dtype = dpnp.float32
338+
else:
339+
dtype = dpnp.float64
340+
scale = numpy.prod(a_shape[index:], dtype=dtype)
243341
norm_factor = 1
244342
if norm == "ortho":
245343
norm_factor = numpy.sqrt(scale)
@@ -329,9 +427,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
329427
raise TypeError("output array should have complex data type.")
330428

331429

430+
def _validate_s_axes(a, s, axes):
431+
if axes is not None:
432+
# validate axes is a sequence and
433+
# each axis is an integer within the range
434+
normalize_axis_tuple(list(set(axes)), a.ndim, "axes")
435+
436+
if s is not None:
437+
raise_error = False
438+
if isinstance(s, Sequence):
439+
if any(not isinstance(s_i, int) for s_i in s):
440+
raise_error = True
441+
else:
442+
raise_error = True
443+
444+
if raise_error:
445+
raise TypeError("`s` must be `None` or a sequence of integers.")
446+
447+
if axes is None:
448+
raise ValueError(
449+
"`axes` should not be `None` if `s` is not `None`."
450+
)
451+
452+
332453
def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
333454
"""Calculates 1-D FFT of the input array along axis"""
334455

456+
_check_norm(norm)
335457
a_ndim = a.ndim
336458
if a_ndim == 0:
337459
raise ValueError("Input array must be at least 1D")
@@ -378,3 +500,67 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
378500
c2c=c2c,
379501
axes=axis,
380502
)
503+
504+
505+
def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
506+
"""Calculates N-D FFT of the input array along axes"""
507+
508+
_check_norm(norm)
509+
if isinstance(axes, (list, tuple)) and len(axes) == 0:
510+
return a
511+
512+
if a.ndim == 0:
513+
if axes is not None:
514+
raise IndexError(
515+
"Input array is 0-dimensional while axis is not `None`."
516+
)
517+
518+
return a
519+
520+
_validate_s_axes(a, s, axes)
521+
s, axes = _cook_nd_args(a, s, axes)
522+
a = _truncate_or_pad(a, s, axes)
523+
# TODO: None, False, False are place holder for future development of
524+
# rfft2, irfft2, rfftn, irfftn
525+
_validate_out_keyword(a, out, None, False, False)
526+
# TODO: True is place holder for future development of
527+
# rfft2, irfft2, rfftn, irfftn
528+
a, in_place = _copy_array(a, True)
529+
530+
if a.size == 0:
531+
return dpnp.get_result_array(a, out=out, casting="same_kind")
532+
533+
len_axes = len(axes)
534+
# OneMKL supports up to 3-dimensional FFT on GPU
535+
# repeated axis in OneMKL FFT is not allowed
536+
if len_axes > 3 or len(set(axes)) < len_axes:
537+
axes_chunk = _extract_axes_chunk(axes, chunk_size=3)
538+
for chunk in axes_chunk:
539+
a = _fft(
540+
a,
541+
norm=norm,
542+
out=out,
543+
forward=forward,
544+
in_place=in_place,
545+
# TODO: c2c=True is place holder for future development of
546+
# rfft2, irfft2, rfftn, irfftn
547+
c2c=True,
548+
axes=chunk,
549+
)
550+
return a
551+
552+
if a.ndim == len_axes:
553+
# non-batch FFT
554+
axes = None
555+
556+
return _fft(
557+
a,
558+
norm=norm,
559+
out=out,
560+
forward=forward,
561+
in_place=in_place,
562+
# TODO: c2c=True is place holder for future development of
563+
# rfft2, irfft2, rfftn, irfftn
564+
c2c=True,
565+
axes=axes,
566+
)

tests/skipped_tests.tbl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp
99
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
1010
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]
1111

12-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
13-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
14-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
15-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
16-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
17-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
18-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
19-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
20-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
21-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
22-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2
23-
24-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
25-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
26-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
27-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
28-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
29-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
30-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
31-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
32-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
33-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
34-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn
35-
3612
tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory
3713

3814
tests/test_umath.py::test_umaths[('divmod', 'ii')]

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,6 @@ tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_pa
111111
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_max
112112
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_min
113113

114-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
115-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
116-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
117-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
118-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
119-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
120-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
121-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
122-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
123-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
124-
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2
125-
126-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
127-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
128-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
129-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
130-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
131-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
132-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
133-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
134-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
135-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
136-
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn
137-
138114
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
139115
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
140116
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1

0 commit comments

Comments
 (0)