|
37 | 37 | # pylint: disable=c-extension-no-member
|
38 | 38 | # pylint: disable=no-name-in-module
|
39 | 39 |
|
| 40 | +from collections.abc import Sequence |
| 41 | + |
40 | 42 | import dpctl
|
41 | 43 | import dpctl.tensor._tensor_impl as ti
|
42 | 44 | import dpctl.utils as dpu
|
43 | 45 | import numpy
|
44 |
| -from dpctl.tensor._numpy_helper import normalize_axis_index |
| 46 | +from dpctl.tensor._numpy_helper import ( |
| 47 | + normalize_axis_index, |
| 48 | + normalize_axis_tuple, |
| 49 | +) |
45 | 50 | from dpctl.utils import ExecutionPlacementError
|
46 | 51 |
|
47 | 52 | import dpnp
|
|
55 | 60 |
|
56 | 61 | __all__ = [
|
57 | 62 | "dpnp_fft",
|
| 63 | + "dpnp_fftn", |
58 | 64 | ]
|
59 | 65 |
|
60 | 66 |
|
@@ -160,6 +166,37 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides):
|
160 | 166 | return result
|
161 | 167 |
|
162 | 168 |
|
| 169 | +# TODO: c2r keyword is place holder for irfftn |
def _cook_nd_args(a, s=None, axes=None, c2r=False):
    """
    Resolve `s` and `axes` into an explicit list of lengths and the axes
    they apply to.

    Parameters
    ----------
    a : array
        Input array; only its `shape` attribute is read here.
    s : sequence of ints, optional
        Requested length along each transformed axis. When `None`, the
        lengths are taken from `a.shape` (restricted to `axes` if given).
        An entry equal to ``-1`` is replaced by the length of `a` along
        the matching axis.
    axes : sequence of ints, optional
        Axes to transform. When `None`, the last ``len(s)`` axes are used.
    c2r : bool, optional
        When ``True`` and `s` was not given, the last length is set to
        ``(a.shape[axes[-1]] - 1) * 2``.

    Returns
    -------
    s : list
        Length to use along each entry of `axes`.
    axes : sequence of ints
        The axes passed in, or the generated list of negative axes when
        `axes` was `None`.

    Raises
    ------
    ValueError
        If an entry of `s` is smaller than 1 and is not ``-1``, or if `s`
        and `axes` have different lengths.

    """

    shapeless = s is None
    if shapeless:
        s = list(a.shape) if axes is None else numpy.take(a.shape, axes)

    for n_points in s:
        # `None` and `-1` are placeholders, not actual lengths
        if n_points is None:
            continue
        if n_points < 1 and n_points != -1:
            raise ValueError(
                f"Invalid number of FFT data points ({n_points}) specified."
            )

    if axes is None:
        # default to the trailing `len(s)` axes
        axes = list(range(-len(s), 0))

    if len(s) != len(axes):
        raise ValueError("Shape and axes have different lengths.")

    s = list(s)
    if c2r and shapeless:
        s[-1] = (a.shape[axes[-1]] - 1) * 2

    # use the whole input array along axis `i` if `s[i] == -1`
    resolved = []
    for length, axis in zip(s, axes):
        resolved.append(a.shape[axis] if length == -1 else length)
    return resolved, axes
| 198 | + |
| 199 | + |
163 | 200 | def _copy_array(x, complex_input):
|
164 | 201 | """
|
165 | 202 | Creating a C-contiguous copy of input array if input array has a negative
|
@@ -205,6 +242,63 @@ def _copy_array(x, complex_input):
|
205 | 242 | return x, copy_flag
|
206 | 243 |
|
207 | 244 |
|
def _extract_axes_chunk(a, chunk_size=3):
    """
    Split the input into a list of lists with each list containing
    only unique values and its length is at most `chunk_size`.

    The input is scanned from left to right. The chunk being built is
    closed as soon as it holds `chunk_size` elements, or when the next
    element is already present in it; in the latter case that element
    starts the next chunk. Elements keep their input order inside each
    chunk.

    Parameters
    ----------
    a : list, tuple
        Input. Its elements must be hashable.
    chunk_size : int
        Maximum number of elements in each chunk.

    Returns
    -------
    out : list of lists
        List of lists with each list containing only unique values
        and its length is at most `chunk_size`.
        The list of chunks is returned in reverse order of construction,
        i.e. the chunk built from the end of the input comes first.

    Examples
    --------
    >>> axes = (0, 1, 2, 3, 4)
    >>> _extract_axes_chunk(axes, chunk_size=3)
    [[3, 4], [0, 1, 2]]

    >>> axes = (0, 1, 2, 3, 4, 4)
    >>> _extract_axes_chunk(axes, chunk_size=3)
    [[4], [3, 4], [0, 1, 2]]

    """

    chunks = []
    current_chunk = []
    seen_elements = set()

    for elem in a:
        if elem in seen_elements:
            # a repeated value cannot share a chunk with its first
            # occurrence: close the current chunk and start a new one
            chunks.append(current_chunk)
            current_chunk = [elem]
            seen_elements = {elem}
        else:
            current_chunk.append(elem)
            seen_elements.add(elem)

        if len(current_chunk) == chunk_size:
            # the chunk is full: close it and start from scratch
            chunks.append(current_chunk)
            current_chunk = []
            seen_elements = set()

    # add the last chunk if it's not empty
    if current_chunk:
        chunks.append(current_chunk)

    return chunks[::-1]
| 300 | + |
| 301 | + |
208 | 302 | def _fft(a, norm, out, forward, in_place, c2c, axes=None):
|
209 | 303 | """Calculates FFT of the input array along the specified axes."""
|
210 | 304 |
|
@@ -239,7 +333,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):
|
239 | 333 |
|
240 | 334 | def _scale_result(res, a_shape, norm, forward, index):
|
241 | 335 | """Scale the result of the FFT according to `norm`."""
|
242 |
| - scale = numpy.prod(a_shape[index:], dtype=res.real.dtype) |
| 336 | + if res.dtype in [dpnp.float32, dpnp.complex64]: |
| 337 | + dtype = dpnp.float32 |
| 338 | + else: |
| 339 | + dtype = dpnp.float64 |
| 340 | + scale = numpy.prod(a_shape[index:], dtype=dtype) |
243 | 341 | norm_factor = 1
|
244 | 342 | if norm == "ortho":
|
245 | 343 | norm_factor = numpy.sqrt(scale)
|
@@ -329,9 +427,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
|
329 | 427 | raise TypeError("output array should have complex data type.")
|
330 | 428 |
|
331 | 429 |
|
def _validate_s_axes(a, s, axes):
    """
    Validate the `s` and `axes` keywords of an N-D FFT function.

    Parameters
    ----------
    a : array
        Input array; only `a.ndim` is read, and only when `axes` is given.
    s : sequence of ints, optional
        Requested lengths. Must be `None` or a sequence whose items are
        all instances of `int`.
    axes : sequence of ints, optional
        Axes to transform. Its distinct values are passed to
        `normalize_axis_tuple` together with `a.ndim`.

    Raises
    ------
    TypeError
        If `s` is neither `None` nor a sequence of integers.
    ValueError
        If `s` is given while `axes` is `None`.

    """

    if axes is not None:
        # validate axes is a sequence and
        # each axis is an integer within the range
        # (duplicates are dropped first, so repeated axes are not rejected
        # at this point)
        normalize_axis_tuple(list(set(axes)), a.ndim, "axes")

    if s is None:
        return

    is_int_sequence = isinstance(s, Sequence) and all(
        isinstance(length, int) for length in s
    )
    if not is_int_sequence:
        raise TypeError("`s` must be `None` or a sequence of integers.")

    if axes is None:
        raise ValueError("`axes` should not be `None` if `s` is not `None`.")
| 451 | + |
| 452 | + |
332 | 453 | def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
|
333 | 454 | """Calculates 1-D FFT of the input array along axis"""
|
334 | 455 |
|
| 456 | + _check_norm(norm) |
335 | 457 | a_ndim = a.ndim
|
336 | 458 | if a_ndim == 0:
|
337 | 459 | raise ValueError("Input array must be at least 1D")
|
@@ -378,3 +500,67 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
|
378 | 500 | c2c=c2c,
|
379 | 501 | axes=axis,
|
380 | 502 | )
|
| 503 | + |
| 504 | + |
def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
    """
    Calculates N-D FFT of the input array along axes.

    The input is returned unchanged when `axes` is an empty list/tuple, or
    when `a` is 0-dimensional and `axes` is `None`. Otherwise `s` and `axes`
    are validated and resolved, `a` is passed through `_truncate_or_pad`
    and `_copy_array`, and the transform is delegated to `_fft`: in a
    single call when there are at most three distinct axes, or chunk by
    chunk (see `_extract_axes_chunk`) otherwise.

    Raises
    ------
    IndexError
        If `a` is 0-dimensional while `axes` is not `None`.

    """

    _check_norm(norm)

    # nothing to transform when an empty list/tuple of axes is given
    if isinstance(axes, (list, tuple)) and len(axes) == 0:
        return a

    if a.ndim == 0:
        if axes is None:
            return a
        raise IndexError(
            "Input array is 0-dimensional while axis is not `None`."
        )

    _validate_s_axes(a, s, axes)
    s, axes = _cook_nd_args(a, s, axes)
    a = _truncate_or_pad(a, s, axes)
    # TODO: None, False, False are place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    _validate_out_keyword(a, out, None, False, False)
    # TODO: True is place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    a, in_place = _copy_array(a, True)

    if a.size == 0:
        return dpnp.get_result_array(a, out=out, casting="same_kind")

    n_axes = len(axes)
    # OneMKL supports up to 3-dimensional FFT on GPU
    # repeated axis in OneMKL FFT is not allowed
    if n_axes > 3 or len(set(axes)) < n_axes:
        axes_per_call = _extract_axes_chunk(axes, chunk_size=3)
    elif a.ndim == n_axes:
        # non-batch FFT
        axes_per_call = [None]
    else:
        axes_per_call = [axes]

    # each call transforms the result of the previous one
    for call_axes in axes_per_call:
        a = _fft(
            a,
            norm=norm,
            out=out,
            forward=forward,
            in_place=in_place,
            # TODO: c2c=True is place holder for future development of
            # rfft2, irfft2, rfftn, irfftn
            c2c=True,
            axes=call_axes,
        )
    return a
0 commit comments