41
41
import dpctl .tensor as dpt
42
42
import dpctl .tensor ._tensor_impl as ti
43
43
import numpy
44
+ from dpctl .utils import ExecutionPlacementError
44
45
from numpy .core .numeric import normalize_axis_index
45
46
46
47
import dpnp
47
48
import dpnp .backend .extensions .fft ._fft_impl as fi
49
+ from dpnp .dpnp_utils .dpnp_utils_linearalgebra import (
50
+ _standardize_strides_to_nonzero ,
51
+ )
48
52
49
53
from ..dpnp_array import dpnp_array
50
54
from ..dpnp_utils import map_dtype_to_device
@@ -62,22 +66,12 @@ def _check_norm(norm):
62
66
)
63
67
64
68
65
- def _fft (a , norm , is_forward , hev_list , dev_list , axes = None ):
66
- """Calculates FFT of the input array along the specified axes."""
67
-
68
- index = 0
69
- if axes is not None : # batch_fft
70
- len_axes = 1 if isinstance (axes , int ) else len (axes )
71
- local_axes = numpy .arange (- len_axes , 0 )
72
- a = dpnp .moveaxis (a , axes , local_axes )
73
- a_shape_orig = a .shape
74
- local_shape = (- 1 ,) + a .shape [- len_axes :]
75
- a = dpnp .reshape (a , local_shape )
76
- index = 1
77
-
78
- shape = a .shape [index :]
79
- strides = (0 ,) + a .strides [index :]
69
+ def _commit_descriptor (a , a_strides , index , axes ):
70
+ """Commit the FFT descriptor for the input array."""
80
71
72
+ a_shape = a .shape
73
+ shape = a_shape [index :]
74
+ strides = (0 ,) + a_strides [index :]
81
75
if a .dtype == dpnp .complex64 :
82
76
dsc = fi .Complex64Descriptor (shape )
83
77
else :
@@ -87,28 +81,95 @@ def _fft(a, norm, is_forward, hev_list, dev_list, axes=None):
87
81
dsc .bwd_strides = dsc .fwd_strides
88
82
dsc .transform_in_place = False
89
83
if axes is not None : # batch_fft
90
- dsc .fwd_distance = a . strides [0 ]
84
+ dsc .fwd_distance = a_strides [0 ]
91
85
dsc .bwd_distance = dsc .fwd_distance
92
- dsc .number_of_transforms = numpy .prod (a . shape [0 ])
86
+ dsc .number_of_transforms = numpy .prod (a_shape [0 ])
93
87
dsc .commit (a .sycl_queue )
94
88
95
- # TODO: replace with dpnp.empty_like when its bug is fixed
96
- # and it returns arrays with the same stride as input array
97
- res = dpt .usm_ndarray (
98
- a .shape ,
99
- dtype = a .dtype ,
100
- buffer = a .usm_type ,
101
- strides = a .strides ,
102
- offset = 0 ,
103
- buffer_ctor_kwargs = {"queue" : a .sycl_queue },
104
- )
105
- fft_event , _ = fi .compute_fft (dsc , a .get_array (), res , is_forward , dev_list )
89
+ return dsc
90
+
91
+
92
def _compute_result(dsc, a, out, is_forward, a_strides, hev_list, dev_list):
    """
    Run the FFT described by `dsc` on `a` and return the result array.

    The transform is written straight into `out` only when `out` can stand
    in for the backend result buffer, i.e. it has the same shape, data type
    and strides as the input and does not overlap it in memory. In every
    other case a fresh buffer with the input's shape, dtype and strides is
    allocated and the caller is expected to move the data into `out`
    afterwards.

    Parameters
    ----------
    dsc : FFT descriptor
        Committed descriptor passed to ``fi.compute_fft``.
    a : dpnp array
        Input array; its USM array is obtained with ``a.get_array()``.
    out : dpnp array or None
        Optional user-provided output array.
    is_forward : bool
        Forwarded unchanged to ``fi.compute_fft``.
    a_strides : tuple
        Strides used both for the comparison with ``out.strides`` and for
        the temporary result buffer.
    hev_list : list
        Events to wait on; the FFT event returned by the backend is
        appended to it before waiting.
    dev_list : list
        Forwarded unchanged to ``fi.compute_fft``
        (presumably the device events the computation depends on).

    Returns
    -------
    dpnp array
        Array wrapping the buffer that received the transform.
    """

    a_usm = a.get_array()
    if (
        out is not None
        # The descriptor was created from the dtype and shape of `a`, so a
        # buffer of another precision or extent must not be handed to the
        # backend even when its strides happen to coincide.
        and out.shape == a.shape
        and out.dtype == a.dtype
        and out.strides == a_strides
        and not ti._array_overlap(a_usm, out.get_array())
    ):
        res_usm = out.get_array()
    else:
        # Result array that is used in OneMKL must have the exact same
        # stride as input array
        res_usm = dpt.usm_ndarray(
            a.shape,
            dtype=a.dtype,
            buffer=a.usm_type,
            strides=a_strides,
            offset=0,
            buffer_ctor_kwargs={"queue": a.sycl_queue},
        )

    fft_event, _ = fi.compute_fft(dsc, a_usm, res_usm, is_forward, dev_list)
    hev_list.append(fft_event)
    # Block until the FFT and every previously recorded event have finished.
    dpctl.SyclEvent.wait_for(hev_list)

    return dpnp_array._create_from_usm_ndarray(res_usm)
120
+
121
+
122
def _copy_array(x, dep_events, host_events):
    """
    Return an array that can be passed to the OneMKL FFT computation.

    A C-contiguous copy of `x` is created when `x` has a negative stride
    or when its data type is not complex. In the latter case the copy is
    made with a complex data type (``complex64`` for ``float32`` input,
    otherwise the device-mapped ``complex128``). When neither condition
    holds, `x` itself is returned and no event is recorded.

    Parameters
    ----------
    x : dpnp array
        Input array.
    dep_events : list
        Extended in place with the device event of the copy, if any.
    host_events : list
        Extended in place with the host event of the copy, if any.

    Returns
    -------
    dpnp array
        Either `x` or its C-contiguous (possibly complex-converted) copy.
    """
    dtype = x.dtype

    # Negative stride is not allowed in OneMKL FFT.
    # `any` is used instead of `numpy.min` so that an empty strides tuple
    # (0-d input) does not raise here and the caller can report its own
    # dimensionality error.
    copy_flag = any(stride < 0 for stride in x.strides)

    # The conversion to complex is independent of the stride check: a real
    # input with a negative stride needs both the copy and the conversion.
    if not dpnp.issubdtype(dtype, dpnp.complexfloating):
        copy_flag = True
        if dtype == dpnp.float32:
            dtype = dpnp.complex64
        else:
            dtype = map_dtype_to_device(dpnp.complex128, x.sycl_device)

    if not copy_flag:
        return x

    x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
    ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=dpnp.get_usm_ndarray(x),
        dst=x_copy.get_array(),
        sycl_queue=x.sycl_queue,
    )
    dep_events.append(copy_ev)
    host_events.append(ht_copy_ev)
    return x_copy
151
+
152
+
153
+ def _fft (a , norm , out , is_forward , hev_list , dev_list , axes = None ):
154
+ """Calculates FFT of the input array along the specified axes."""
110
155
111
- scale = numpy .prod (shape , dtype = a .real .dtype )
156
+ index = 0
157
+ if axes is not None : # batch_fft
158
+ len_axes = 1 if isinstance (axes , int ) else len (axes )
159
+ local_axes = numpy .arange (- len_axes , 0 )
160
+ a = dpnp .moveaxis (a , axes , local_axes )
161
+ a_shape_orig = a .shape
162
+ local_shape = (- 1 ,) + a_shape_orig [- len_axes :]
163
+ a = dpnp .reshape (a , local_shape )
164
+ index = 1
165
+
166
+ a_strides = _standardize_strides_to_nonzero (a .strides , a .shape )
167
+ dsc = _commit_descriptor (a , a_strides , index , axes )
168
+ res = _compute_result (
169
+ dsc , a , out , is_forward , a_strides , hev_list , dev_list
170
+ )
171
+
172
+ scale = numpy .prod (a .shape [index :], dtype = a .real .dtype )
112
173
norm_factor = 1
113
174
if norm == "ortho" :
114
175
norm_factor = numpy .sqrt (scale )
@@ -121,14 +182,16 @@ def _fft(a, norm, is_forward, hev_list, dev_list, axes=None):
121
182
if axes is not None : # batch_fft
122
183
res = dpnp .reshape (res , a_shape_orig )
123
184
res = dpnp .moveaxis (res , local_axes , axes )
124
- return res
185
+
186
+ result = dpnp .get_result_array (res , out = out , casting = "same_kind" )
187
+ if not (result .flags .c_contiguous or result .flags .f_contiguous ):
188
+ result = dpnp .ascontiguousarray (result )
189
+ return result
125
190
126
191
127
- def _truncate_or_pad (a , shape , axes ):
192
+ def _truncate_or_pad (a , shape , axes , copy_ht_ev , copy_dp_ev ):
128
193
"""Truncating or zero-padding the input array along the specified axes."""
129
194
130
- copy_ht_ev = []
131
- copy_dp_ev = []
132
195
shape = (shape ,) if isinstance (shape , int ) else shape
133
196
axes = (axes ,) if isinstance (axes , int ) else axes
134
197
@@ -146,58 +209,77 @@ def _truncate_or_pad(a, shape, axes):
146
209
exec_q = a .sycl_queue
147
210
index [axis ] = slice (0 , a_shape [axis ]) # orig shape
148
211
a_shape [axis ] = s # modified shape
212
+ order = "F" if a .flags .f_contiguous else "C"
149
213
z = dpnp .zeros (
150
214
a_shape ,
151
215
dtype = a .dtype ,
216
+ order = order ,
152
217
usm_type = a .usm_type ,
153
218
sycl_queue = exec_q ,
154
219
)
155
220
ht_ev , dp_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
156
221
src = a .get_array (),
157
222
dst = z .get_array ()[tuple (index )],
158
223
sycl_queue = exec_q ,
224
+ depends = copy_dp_ev ,
159
225
)
160
226
copy_ht_ev .append (ht_ev )
161
227
copy_dp_ev .append (dp_ev )
162
228
a = z
163
229
164
- return a , copy_ht_ev , copy_dp_ev
230
+ return a
231
+
232
+
233
+ def _validate_out_keyword (a , out ):
234
+ """Validate out keyword argument."""
235
+ if out is not None :
236
+ dpnp .check_supported_arrays_type (out )
237
+ if (
238
+ dpctl .utils .get_execution_queue ((a .sycl_queue , out .sycl_queue ))
239
+ is None
240
+ ):
241
+ raise ExecutionPlacementError (
242
+ "Input and output allocation queues are not compatible"
243
+ )
244
+
245
+ if out .shape != a .shape :
246
+ raise ValueError ("output array has incorrect shape." )
247
+
248
+ if not dpnp .issubdtype (out .dtype , dpnp .complexfloating ):
249
+ raise TypeError ("output array has incorrect data type." )
165
250
166
251
167
- def dpnp_fft (a , is_forward , n = None , axis = - 1 , norm = None ):
252
+ def dpnp_fft (a , is_forward , n = None , axis = - 1 , norm = None , out = None ):
168
253
"""Calculates 1-D FFT of the input array along axis"""
169
254
170
255
_check_norm (norm )
171
- if not dpnp .issubdtype (a .dtype , dpnp .complexfloating ):
172
- if a .dtype == dpnp .float32 :
173
- dtype = dpnp .complex64
174
- else :
175
- dtype = map_dtype_to_device (dpnp .complex128 , a .sycl_device )
176
- a = dpnp .astype (a , dtype , copy = False )
256
+ a_ndim = a .ndim
257
+ copy_ht_ev = []
258
+ copy_dp_ev = []
259
+ a = _copy_array (a , copy_ht_ev , copy_dp_ev )
177
260
178
- if a . ndim == 0 :
261
+ if a_ndim == 0 :
179
262
raise ValueError ("Input array must be at least 1D" )
180
263
181
- axis = normalize_axis_index (axis , a . ndim )
264
+ axis = normalize_axis_index (axis , a_ndim )
182
265
if n is None :
183
266
n = a .shape [axis ]
184
267
if not isinstance (n , int ):
185
268
raise TypeError ("`n` should be None or an integer" )
186
269
if n < 1 :
187
270
raise ValueError (f"Invalid number of FFT data points ({ n } ) specified" )
188
271
189
- a , copy_ht_ev , copy_dp_ev = _truncate_or_pad (a , n , axis )
272
+ a = _truncate_or_pad (a , n , axis , copy_ht_ev , copy_dp_ev )
273
+ _validate_out_keyword (a , out )
274
+
190
275
if a .size == 0 :
191
- if a .shape [axis ] == 0 :
192
- raise ValueError (
193
- f"Invalid number of FFT data points ({ 0 } ) specified."
194
- )
195
276
return a
196
277
197
- if a . ndim == 1 :
278
+ if a_ndim == 1 :
198
279
return _fft (
199
280
a ,
200
281
norm = norm ,
282
+ out = out ,
201
283
is_forward = is_forward ,
202
284
hev_list = copy_ht_ev ,
203
285
dev_list = copy_dp_ev ,
@@ -206,6 +288,7 @@ def dpnp_fft(a, is_forward, n=None, axis=-1, norm=None):
206
288
return _fft (
207
289
a ,
208
290
norm = norm ,
291
+ out = out ,
209
292
is_forward = is_forward ,
210
293
axes = axis ,
211
294
hev_list = copy_ht_ev ,
0 commit comments