Skip to content

Commit ea048ac

Browse files
committed
add out keyword and fix an issue with negative stride
1 parent 9235a29 commit ea048ac

File tree

3 files changed

+248
-54
lines changed

3 files changed

+248
-54
lines changed

dpnp/fft/dpnp_iface_fft.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_validated_norm(norm):
101101
raise ValueError("Unknown norm value.")
102102

103103

104-
def fft(a, n=None, axis=-1, norm=None):
104+
def fft(a, n=None, axis=-1, norm=None, out=None):
105105
"""
106106
Compute the one-dimensional discrete Fourier Transform.
107107
@@ -126,6 +126,9 @@ def fft(a, n=None, axis=-1, norm=None):
126126
is scaled and with what normalization factor. ``None`` is an alias of
127127
the default option "backward".
128128
Default: "backward".
129+
out : dpnp.ndarray of complex dtype, optional
130+
If provided, the result will be placed in this array. It should be
131+
of the appropriate shape and dtype.
129132
130133
Returns
131134
-------
@@ -162,7 +165,7 @@ def fft(a, n=None, axis=-1, norm=None):
162165
"""
163166

164167
dpnp.check_supported_arrays_type(a)
165-
return dpnp_fft(a, is_forward=True, n=n, axis=axis, norm=norm)
168+
return dpnp_fft(a, is_forward=True, n=n, axis=axis, norm=norm, out=out)
166169

167170

168171
def fft2(x, s=None, axes=(-2, -1), norm=None):
@@ -362,7 +365,7 @@ def hfft(x, n=None, axis=-1, norm=None):
362365
return call_origin(numpy.fft.hfft, x, n, axis, norm)
363366

364367

365-
def ifft(a, n=None, axis=-1, norm=None):
368+
def ifft(a, n=None, axis=-1, norm=None, out=None):
366369
"""
367370
Compute the one-dimensional inverse discrete Fourier Transform.
368371
@@ -387,6 +390,9 @@ def ifft(a, n=None, axis=-1, norm=None):
387390
is scaled and with what normalization factor. ``None`` is an alias of
388391
the default option "backward".
389392
Default: "backward"
393+
out : dpnp.ndarray of complex dtype, optional
394+
If provided, the result will be placed in this array. It should be
395+
of the appropriate shape and dtype.
390396
391397
Returns
392398
-------
@@ -419,7 +425,7 @@ def ifft(a, n=None, axis=-1, norm=None):
419425
"""
420426

421427
dpnp.check_supported_arrays_type(a)
422-
return dpnp_fft(a, is_forward=False, n=n, axis=axis, norm=norm)
428+
return dpnp_fft(a, is_forward=False, n=n, axis=axis, norm=norm, out=out)
423429

424430

425431
def ifft2(x, s=None, axes=(-2, -1), norm=None):

dpnp/fft/dpnp_utils_fft.py

Lines changed: 133 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,14 @@
4141
import dpctl.tensor as dpt
4242
import dpctl.tensor._tensor_impl as ti
4343
import numpy
44+
from dpctl.utils import ExecutionPlacementError
4445
from numpy.core.numeric import normalize_axis_index
4546

4647
import dpnp
4748
import dpnp.backend.extensions.fft._fft_impl as fi
49+
from dpnp.dpnp_utils.dpnp_utils_linearalgebra import (
50+
_standardize_strides_to_nonzero,
51+
)
4852

4953
from ..dpnp_array import dpnp_array
5054
from ..dpnp_utils import map_dtype_to_device
@@ -62,22 +66,12 @@ def _check_norm(norm):
6266
)
6367

6468

65-
def _fft(a, norm, is_forward, hev_list, dev_list, axes=None):
66-
"""Calculates FFT of the input array along the specified axes."""
67-
68-
index = 0
69-
if axes is not None: # batch_fft
70-
len_axes = 1 if isinstance(axes, int) else len(axes)
71-
local_axes = numpy.arange(-len_axes, 0)
72-
a = dpnp.moveaxis(a, axes, local_axes)
73-
a_shape_orig = a.shape
74-
local_shape = (-1,) + a.shape[-len_axes:]
75-
a = dpnp.reshape(a, local_shape)
76-
index = 1
77-
78-
shape = a.shape[index:]
79-
strides = (0,) + a.strides[index:]
69+
def _commit_descriptor(a, a_strides, index, axes):
70+
"""Commit the FFT descriptor for the input array."""
8071

72+
a_shape = a.shape
73+
shape = a_shape[index:]
74+
strides = (0,) + a_strides[index:]
8175
if a.dtype == dpnp.complex64:
8276
dsc = fi.Complex64Descriptor(shape)
8377
else:
@@ -87,28 +81,95 @@ def _fft(a, norm, is_forward, hev_list, dev_list, axes=None):
8781
dsc.bwd_strides = dsc.fwd_strides
8882
dsc.transform_in_place = False
8983
if axes is not None: # batch_fft
90-
dsc.fwd_distance = a.strides[0]
84+
dsc.fwd_distance = a_strides[0]
9185
dsc.bwd_distance = dsc.fwd_distance
92-
dsc.number_of_transforms = numpy.prod(a.shape[0])
86+
dsc.number_of_transforms = numpy.prod(a_shape[0])
9387
dsc.commit(a.sycl_queue)
9488

95-
# TODO: replace with dpnp.empty_like when its bug is fixed
96-
# and it returns arrays with the same stride as input array
97-
res = dpt.usm_ndarray(
98-
a.shape,
99-
dtype=a.dtype,
100-
buffer=a.usm_type,
101-
strides=a.strides,
102-
offset=0,
103-
buffer_ctor_kwargs={"queue": a.sycl_queue},
104-
)
105-
fft_event, _ = fi.compute_fft(dsc, a.get_array(), res, is_forward, dev_list)
89+
return dsc
90+
91+
92+
def _compute_result(dsc, a, out, is_forward, a_strides, hev_list, dev_list):
93+
"""Compute the result of the FFT."""
94+
95+
a_usm = a.get_array()
96+
if (
97+
out is not None
98+
and out.strides == a_strides
99+
and not ti._array_overlap(a_usm, out.get_array())
100+
):
101+
res_usm = out.get_array()
102+
else:
103+
# Result array that is used in OneMKL must have the exact same
104+
# stride as input array
105+
res_usm = dpt.usm_ndarray(
106+
a.shape,
107+
dtype=a.dtype,
108+
buffer=a.usm_type,
109+
strides=a_strides,
110+
offset=0,
111+
buffer_ctor_kwargs={"queue": a.sycl_queue},
112+
)
113+
fft_event, _ = fi.compute_fft(dsc, a_usm, res_usm, is_forward, dev_list)
106114
hev_list.append(fft_event)
107115
dpctl.SyclEvent.wait_for(hev_list)
108116

109-
res = dpnp_array._create_from_usm_ndarray(res)
117+
res = dpnp_array._create_from_usm_ndarray(res_usm)
118+
119+
return res
120+
121+
122+
def _copy_array(x, dep_events, host_events):
123+
"""
124+
Creating a C-contiguous copy of input array if input array has a negative
125+
stride or it does not have a complex data types.
126+
"""
127+
dtype = x.dtype
128+
copy_flag = False
129+
if numpy.min(x.strides) < 0:
130+
# negative stride is not allowed in OneMKL FFT
131+
copy_flag = True
132+
elif not dpnp.issubdtype(dtype, dpnp.complexfloating):
133+
# if input is not complex, convert to complex
134+
copy_flag = True
135+
if dtype == dpnp.float32:
136+
dtype = dpnp.complex64
137+
else:
138+
dtype = map_dtype_to_device(dpnp.complex128, x.sycl_device)
139+
140+
if copy_flag:
141+
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
142+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
143+
src=dpnp.get_usm_ndarray(x),
144+
dst=x_copy.get_array(),
145+
sycl_queue=x.sycl_queue,
146+
)
147+
dep_events.append(copy_ev)
148+
host_events.append(ht_copy_ev)
149+
return x_copy
150+
return x
151+
152+
153+
def _fft(a, norm, out, is_forward, hev_list, dev_list, axes=None):
154+
"""Calculates FFT of the input array along the specified axes."""
110155

111-
scale = numpy.prod(shape, dtype=a.real.dtype)
156+
index = 0
157+
if axes is not None: # batch_fft
158+
len_axes = 1 if isinstance(axes, int) else len(axes)
159+
local_axes = numpy.arange(-len_axes, 0)
160+
a = dpnp.moveaxis(a, axes, local_axes)
161+
a_shape_orig = a.shape
162+
local_shape = (-1,) + a_shape_orig[-len_axes:]
163+
a = dpnp.reshape(a, local_shape)
164+
index = 1
165+
166+
a_strides = _standardize_strides_to_nonzero(a.strides, a.shape)
167+
dsc = _commit_descriptor(a, a_strides, index, axes)
168+
res = _compute_result(
169+
dsc, a, out, is_forward, a_strides, hev_list, dev_list
170+
)
171+
172+
scale = numpy.prod(a.shape[index:], dtype=a.real.dtype)
112173
norm_factor = 1
113174
if norm == "ortho":
114175
norm_factor = numpy.sqrt(scale)
@@ -121,14 +182,16 @@ def _fft(a, norm, is_forward, hev_list, dev_list, axes=None):
121182
if axes is not None: # batch_fft
122183
res = dpnp.reshape(res, a_shape_orig)
123184
res = dpnp.moveaxis(res, local_axes, axes)
124-
return res
185+
186+
result = dpnp.get_result_array(res, out=out, casting="same_kind")
187+
if not (result.flags.c_contiguous or result.flags.f_contiguous):
188+
result = dpnp.ascontiguousarray(result)
189+
return result
125190

126191

127-
def _truncate_or_pad(a, shape, axes):
192+
def _truncate_or_pad(a, shape, axes, copy_ht_ev, copy_dp_ev):
128193
"""Truncating or zero-padding the input array along the specified axes."""
129194

130-
copy_ht_ev = []
131-
copy_dp_ev = []
132195
shape = (shape,) if isinstance(shape, int) else shape
133196
axes = (axes,) if isinstance(axes, int) else axes
134197

@@ -146,58 +209,77 @@ def _truncate_or_pad(a, shape, axes):
146209
exec_q = a.sycl_queue
147210
index[axis] = slice(0, a_shape[axis]) # orig shape
148211
a_shape[axis] = s # modified shape
212+
order = "F" if a.flags.f_contiguous else "C"
149213
z = dpnp.zeros(
150214
a_shape,
151215
dtype=a.dtype,
216+
order=order,
152217
usm_type=a.usm_type,
153218
sycl_queue=exec_q,
154219
)
155220
ht_ev, dp_ev = ti._copy_usm_ndarray_into_usm_ndarray(
156221
src=a.get_array(),
157222
dst=z.get_array()[tuple(index)],
158223
sycl_queue=exec_q,
224+
depends=copy_dp_ev,
159225
)
160226
copy_ht_ev.append(ht_ev)
161227
copy_dp_ev.append(dp_ev)
162228
a = z
163229

164-
return a, copy_ht_ev, copy_dp_ev
230+
return a
231+
232+
233+
def _validate_out_keyword(a, out):
234+
"""Validate out keyword argument."""
235+
if out is not None:
236+
dpnp.check_supported_arrays_type(out)
237+
if (
238+
dpctl.utils.get_execution_queue((a.sycl_queue, out.sycl_queue))
239+
is None
240+
):
241+
raise ExecutionPlacementError(
242+
"Input and output allocation queues are not compatible"
243+
)
244+
245+
if out.shape != a.shape:
246+
raise ValueError("output array has incorrect shape.")
247+
248+
if not dpnp.issubdtype(out.dtype, dpnp.complexfloating):
249+
raise TypeError("output array has incorrect data type.")
165250

166251

167-
def dpnp_fft(a, is_forward, n=None, axis=-1, norm=None):
252+
def dpnp_fft(a, is_forward, n=None, axis=-1, norm=None, out=None):
168253
"""Calculates 1-D FFT of the input array along axis"""
169254

170255
_check_norm(norm)
171-
if not dpnp.issubdtype(a.dtype, dpnp.complexfloating):
172-
if a.dtype == dpnp.float32:
173-
dtype = dpnp.complex64
174-
else:
175-
dtype = map_dtype_to_device(dpnp.complex128, a.sycl_device)
176-
a = dpnp.astype(a, dtype, copy=False)
256+
a_ndim = a.ndim
257+
copy_ht_ev = []
258+
copy_dp_ev = []
259+
a = _copy_array(a, copy_ht_ev, copy_dp_ev)
177260

178-
if a.ndim == 0:
261+
if a_ndim == 0:
179262
raise ValueError("Input array must be at least 1D")
180263

181-
axis = normalize_axis_index(axis, a.ndim)
264+
axis = normalize_axis_index(axis, a_ndim)
182265
if n is None:
183266
n = a.shape[axis]
184267
if not isinstance(n, int):
185268
raise TypeError("`n` should be None or an integer")
186269
if n < 1:
187270
raise ValueError(f"Invalid number of FFT data points ({n}) specified")
188271

189-
a, copy_ht_ev, copy_dp_ev = _truncate_or_pad(a, n, axis)
272+
a = _truncate_or_pad(a, n, axis, copy_ht_ev, copy_dp_ev)
273+
_validate_out_keyword(a, out)
274+
190275
if a.size == 0:
191-
if a.shape[axis] == 0:
192-
raise ValueError(
193-
f"Invalid number of FFT data points ({0}) specified."
194-
)
195276
return a
196277

197-
if a.ndim == 1:
278+
if a_ndim == 1:
198279
return _fft(
199280
a,
200281
norm=norm,
282+
out=out,
201283
is_forward=is_forward,
202284
hev_list=copy_ht_ev,
203285
dev_list=copy_dp_ev,
@@ -206,6 +288,7 @@ def dpnp_fft(a, is_forward, n=None, axis=-1, norm=None):
206288
return _fft(
207289
a,
208290
norm=norm,
291+
out=out,
209292
is_forward=is_forward,
210293
axes=axis,
211294
hev_list=copy_ht_ev,

0 commit comments

Comments
 (0)