Skip to content

Commit 56ae34a

Browse files
committed
Add dtype to fftfreq/rfftfreq
1 parent 395c896 commit 56ae34a

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

dpnp/fft/dpnp_iface_fft.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .dpnp_utils_fft import (
4242
dpnp_fft,
4343
dpnp_fftn,
44+
dpnp_fillfreq,
4445
)
4546

4647
__all__ = [
@@ -257,7 +258,9 @@ def fft2(a, s=None, axes=(-2, -1), norm=None, out=None):
257258
)
258259

259260

260-
def fftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None):
261+
def fftfreq(
262+
n, /, *, d=1.0, dtype=None, device=None, usm_type=None, sycl_queue=None
263+
):
261264
"""
262265
Return the Discrete Fourier Transform sample frequencies.
263266
@@ -279,6 +282,12 @@ def fftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None):
279282
d : scalar, optional
280283
Sample spacing (inverse of the sampling rate).
281284
Default: ``1.0``.
285+
dtype : {None, str, dtype object}, optional
286+
The output array data type. Must be a real-valued floating-point data
287+
type. If `dtype` is ``None``, the output array data type must be the
288+
default real-valued floating-point data type.
289+
290+
Default: ``None``.
282291
device : {None, string, SyclDevice, SyclQueue, Device}, optional
283292
An array API concept of device where the output array is created.
284293
`device` can be ``None``, a oneAPI filter selector string, an instance
@@ -342,23 +351,19 @@ def fftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None):
342351
if not dpnp.isscalar(d):
343352
raise ValueError("`d` should be an scalar")
344353

345-
cfd_kwarg = {
346-
"device": device,
347-
"usm_type": usm_type,
348-
"sycl_queue": sycl_queue,
349-
}
354+
if dtype and not dpnp.issubdtype(dtype, dpnp.floating):
355+
raise ValueError(
356+
"dtype must a real-valued floating-point data type, "
357+
f"but got {dtype}"
358+
)
350359

351360
val = 1.0 / (n * d)
352-
results = dpnp.empty(n, dtype=dpnp.intp, **cfd_kwarg)
361+
results = dpnp.empty(
362+
n, dtype=dtype, device=device, usm_type=usm_type, sycl_queue=sycl_queue
363+
)
353364

354365
m = (n - 1) // 2 + 1
355-
p1 = dpnp.arange(0, m, dtype=dpnp.intp, **cfd_kwarg)
356-
357-
results[:m] = p1
358-
p2 = dpnp.arange(m - n, 0, dtype=dpnp.intp, **cfd_kwarg)
359-
360-
results[m:] = p2
361-
return results * val
366+
return dpnp_fillfreq(results, m, n, val)
362367

363368

364369
def fftn(a, s=None, axes=None, norm=None, out=None):
@@ -1507,7 +1512,9 @@ def rfft2(a, s=None, axes=(-2, -1), norm=None, out=None):
15071512
)
15081513

15091514

1510-
def rfftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None):
1515+
def rfftfreq(
1516+
n, /, *, d=1.0, dtype=None, device=None, usm_type=None, sycl_queue=None
1517+
):
15111518
"""
15121519
Return the Discrete Fourier Transform sample frequencies
15131520
(for usage with :obj:`dpnp.fft.rfft`, :obj:`dpnp.fft.irfft`).
@@ -1533,6 +1540,12 @@ def rfftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None):
15331540
d : scalar, optional
15341541
Sample spacing (inverse of the sampling rate).
15351542
Default: ``1.0``.
1543+
dtype : {None, str, dtype object}, optional
1544+
The output array data type. Must be a real-valued floating-point data
1545+
type. If `dtype` is ``None``, the output array data type must be the
1546+
default real-valued floating-point data type.
1547+
1548+
Default: ``None``.
15361549
device : {None, string, SyclDevice, SyclQueue, Device}, optional
15371550
An array API concept of device where the output array is created.
15381551
`device` can be ``None``, a oneAPI filter selector string, an instance
@@ -1598,12 +1611,19 @@ def rfftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None):
15981611
raise ValueError("`n` should be an integer")
15991612
if not dpnp.isscalar(d):
16001613
raise ValueError("`d` should be an scalar")
1614+
1615+
if dtype and not dpnp.issubdtype(dtype, dpnp.floating):
1616+
raise ValueError(
1617+
"dtype must a real-valued floating-point data type, "
1618+
f"but got {dtype}"
1619+
)
1620+
16011621
val = 1.0 / (n * d)
16021622
m = n // 2 + 1
16031623
results = dpnp.arange(
16041624
0,
16051625
m,
1606-
dtype=dpnp.intp,
1626+
dtype=dtype,
16071627
device=device,
16081628
usm_type=usm_type,
16091629
sycl_queue=sycl_queue,

dpnp/fft/dpnp_utils_fft.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
__all__ = [
6161
"dpnp_fft",
6262
"dpnp_fftn",
63+
"dpnp_fillfreq",
6364
]
6465

6566

@@ -698,3 +699,18 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
698699
return _complex_nd_fft(
699700
a, s, norm, out, forward, in_place, c2c, axes, a.ndim != len_axes
700701
)
702+
703+
704+
def dpnp_fillfreq(a, m, n, val):
705+
"""Fill an array with the sample frequencies"""
706+
707+
exec_q = a.sycl_queue
708+
_manager = dpctl.utils.SequentialOrderManager[exec_q]
709+
710+
# it's assumed there is no dependent events to fill the array
711+
ht_lin_ev, lin_ev = ti._linspace_step(0, m, a[:m].get_array(), exec_q)
712+
_manager.add_event_pair(ht_lin_ev, lin_ev)
713+
714+
ht_lin_ev, lin_ev = ti._linspace_step(m - n, 0, a[m:].get_array(), exec_q)
715+
_manager.add_event_pair(ht_lin_ev, lin_ev)
716+
return a * val

0 commit comments

Comments
 (0)