Skip to content

Commit 3c70c7e

Browse files
committed
fix rfft
1 parent 8afa83c commit 3c70c7e

File tree

5 files changed

+92
-57
lines changed

5 files changed

+92
-57
lines changed

dpnp/backend/include/dpnp_iface_fft.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
* @param[in] input_boundarie Limit number of elements for @ref axis.
6363
* @param[in] inverse Using inverse algorithm.
6464
* @param[in] norm Normalization mode. 0 - backward, 1 - forward, 2 - ortho.
65+
* @param[in] real Real mode. 1 if the type of array1_in is real and the size of result_out is input_size / 2 + 1, else 0.
6566
* @param[in] dep_event_vec_ref Reference to vector of SYCL events.
6667
*/
6768
template <typename _DataType>
@@ -75,5 +76,6 @@ INP_DLLEXPORT void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
7576
long input_boundarie,
7677
size_t inverse,
7778
const size_t norm,
79+
const size_t real,
7880
const DPCTLEventVectorRef dep_event_vec_ref);
7981
#endif // BACKEND_IFACE_FFT_H

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
175175
const void* array1_in,
176176
void* result_out,
177177
const shape_elem_type* input_shape,
178+
const shape_elem_type* result_shape,
178179
const size_t shape_size,
180+
const size_t input_size,
179181
const size_t result_size,
180182
_Descriptor_type& desc,
181183
const size_t norm)
@@ -187,7 +189,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
187189

188190
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
189191

190-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, result_size);
192+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
191193
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size);
192194
_DataType_input* array_1 = input1_ptr.get_ptr();
193195
_DataType_output* result = result_ptr.get_ptr();
@@ -227,72 +229,81 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
227229
const void* array1_in,
228230
void* result_out,
229231
const shape_elem_type* input_shape,
232+
const shape_elem_type* result_shape,
230233
const size_t shape_size,
234+
const size_t input_size,
231235
const size_t result_size,
232236
_Descriptor_type& desc,
233-
const size_t norm)
237+
const size_t norm,
238+
const size_t real)
234239
{
235240
if (!shape_size)
236241
{
237242
return;
238243
}
239244

240-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, result_size);
245+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
241246
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size * 2, true, true);
242247
_DataType_input* array_1 = input1_ptr.get_ptr();
243248
_DataType_output* result = result_ptr.get_ptr();
244249

250+
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
251+
245252
const size_t n_iter =
246253
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
247254

248-
const size_t shift = input_shape[shape_size - 1];
255+
const size_t input_shift = input_shape[shape_size - 1];
256+
const size_t result_shift = result_shape[shape_size - 1];;
249257

250258
double forward_scale = 1.0;
251-
double backward_scale = 1.0 / shift;
259+
double backward_scale = 1.0 / input_shift;
252260

253261
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
254262
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
255263

256-
desc.commit(DPNP_QUEUE);
264+
desc.commit(q);
257265

258266
std::vector<sycl::event> fft_events;
259267
fft_events.reserve(n_iter);
260268

261269
for (size_t i = 0; i < n_iter; ++i) {
262-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift * 2));
270+
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
263271
}
264272

265273
sycl::event::wait(fft_events);
266274

267-
size_t n_conj = shift % 2 == 0 ? shift / 2 - 1 : shift / 2;
275+
if (!real) {
268276

269-
sycl::event event;
277+
size_t n_conj = result_shift % 2 == 0 ? result_shift / 2 - 1 : result_shift / 2;
270278

271-
sycl::range<2> gws(n_iter, n_conj);
279+
sycl::event event;
272280

273-
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
274-
size_t i = global_id[0];
275-
{
276-
size_t j = global_id[1];
281+
sycl::range<2> gws(n_iter, n_conj);
282+
283+
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
284+
size_t i = global_id[0];
277285
{
278-
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + shift * (i + 1) - (j + 1)) = std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + shift * i + (j + 1)));
286+
size_t j = global_id[1];
287+
{
288+
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * (i + 1) - (j + 1)) = std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
289+
}
279290
}
280-
}
281-
};
291+
};
282292

283-
auto kernel_func = [&](sycl::handler& cgh) {
284-
cgh.parallel_for<class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
285-
gws, kernel_parallel_for_func);
286-
};
293+
auto kernel_func = [&](sycl::handler& cgh) {
294+
cgh.parallel_for<class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
295+
gws, kernel_parallel_for_func);
296+
};
287297

288-
event = DPNP_QUEUE.submit(kernel_func);
289-
event.wait();
298+
event = q.submit(kernel_func);
299+
event.wait();
300+
}
290301

291302
return;
292303
}
293304

294305
template <typename _DataType_input, typename _DataType_output>
295-
void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
306+
DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
296307
const void* array1_in,
297308
void* result_out,
298309
const shape_elem_type* input_shape,
@@ -302,10 +313,9 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
302313
long input_boundarie,
303314
size_t inverse,
304315
const size_t norm,
316+
const size_t real,
305317
const DPCTLEventVectorRef dep_event_vec_ref)
306318
{
307-
(void)dep_event_vec_ref;
308-
309319
DPCTLSyclEventRef event_ref = nullptr;
310320

311321
if (!shape_size || !array1_in || !result_out)
@@ -317,8 +327,6 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
317327
std::accumulate(result_shape, result_shape + shape_size, 1, std::multiplies<shape_elem_type>());
318328
const size_t input_size =
319329
std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<shape_elem_type>());
320-
321-
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
322330

323331
size_t dim = input_shape[shape_size - 1];
324332

@@ -330,15 +338,15 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
330338
{
331339
desc_dp_cmplx_t desc(dim);
332340
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
333-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
341+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
334342
}
335343
/* complex-to-complex, single precision */
336344
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
337345
std::is_same<_DataType_output, std::complex<float>>::value)
338346
{
339347
desc_sp_cmplx_t desc(dim);
340348
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
341-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
349+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
342350
}
343351
/* real-to-complex, double precision */
344352
else if constexpr (std::is_same<_DataType_input, double>::value &&
@@ -347,36 +355,36 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
347355
desc_dp_real_t desc(dim);
348356

349357
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
350-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
358+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real);
351359
}
352360
/* real-to-complex, single precision */
353361
else if constexpr (std::is_same<_DataType_input, float>::value &&
354362
std::is_same<_DataType_output, std::complex<float>>::value)
355363
{
356364
desc_sp_real_t desc(dim); // try: 2 * result_size
357365
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
358-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
366+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real);
359367
}
360368
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
361369
std::is_same<_DataType_input, int64_t>::value)
362370
{
363371
double* array1_copy = reinterpret_cast<double*>(dpnp_memory_alloc_c(input_size * sizeof(double)));
364372

365-
shape_elem_type* copy_strides = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(sizeof(shape_elem_type)));
373+
shape_elem_type* copy_strides = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
366374
*copy_strides = 1;
367-
shape_elem_type* copy_shape = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(sizeof(shape_elem_type)));
375+
shape_elem_type* copy_shape = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
368376
*copy_shape = input_size;
369377
shape_elem_type copy_shape_size = 1;
370-
dpnp_copyto_c<_DataType_input, double>(array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
371-
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL);
378+
dpnp_copyto_c<_DataType_input, double>(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
379+
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL, dep_event_vec_ref);
372380

373381
desc_dp_real_t desc(dim);
374382
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
375-
array1_copy, result_out, input_shape, shape_size, result_size, desc, norm);
383+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real);
376384

377-
dpnp_memory_free_c(array1_copy);
378-
dpnp_memory_free_c(copy_strides);
379-
dpnp_memory_free_c(copy_shape);
385+
dpnp_memory_free_c(q_ref, array1_copy);
386+
dpnp_memory_free_c(q_ref, copy_strides);
387+
dpnp_memory_free_c(q_ref, copy_shape);
380388
}
381389
else
382390
{
@@ -406,7 +414,8 @@ void dpnp_fft_fft_c(const void* array1_in,
406414
long axis,
407415
long input_boundarie,
408416
size_t inverse,
409-
const size_t norm)
417+
const size_t norm,
418+
const size_t real)
410419
{
411420
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
412421
DPCTLEventVectorRef dep_event_vec_ref = nullptr;
@@ -420,6 +429,7 @@ void dpnp_fft_fft_c(const void* array1_in,
420429
input_boundarie,
421430
inverse,
422431
norm,
432+
real,
423433
dep_event_vec_ref);
424434
DPCTLEvent_WaitAndThrow(event_ref);
425435
}
@@ -433,6 +443,7 @@ void (*dpnp_fft_fft_default_c)(const void*,
433443
long,
434444
long,
435445
size_t,
446+
const size_t,
436447
const size_t) = dpnp_fft_fft_c<_DataType_input, _DataType_output>;
437448

438449
template <typename _DataType_input, typename _DataType_output>
@@ -446,6 +457,7 @@ DPCTLSyclEventRef (*dpnp_fft_fft_ext_c)(DPCTLSyclQueueRef,
446457
long,
447458
size_t,
448459
const size_t,
460+
const size_t,
449461
const DPCTLEventVectorRef) = dpnp_fft_fft_c<_DataType_input, _DataType_output>;
450462

451463
void func_map_init_fft_func(func_map_t& fmap)

dpnp/fft/dpnp_algo_fft.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,16 @@ __all__ = [
4242
]
4343

4444
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , shape_elem_type * , shape_elem_type * ,
45-
size_t, long, long, size_t, size_t)
45+
size_t, long, long, size_t, size_t, size_t)
4646

4747

4848
cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
4949
size_t input_boundarie,
5050
size_t output_boundarie,
5151
long axis,
5252
size_t inverse,
53-
size_t norm):
53+
size_t norm,
54+
size_t real):
5455

5556
cdef shape_type_c input_shape = input.shape
5657
cdef shape_type_c output_shape = input_shape
@@ -70,6 +71,6 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
7071
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
7172
# call FPTR function
7273
func(input.get_data(), result.get_data(), input_shape.data(),
73-
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse, norm)
74+
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse, norm, real)
7475

7576
return result

dpnp/fft/dpnp_iface_fft.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,13 @@ def fft(x1, n=None, axis=-1, norm=None):
120120
pass # let fallback to handle exception
121121
elif norm is not None:
122122
pass
123-
elif axis != -1:
123+
elif n is not None:
124124
pass
125-
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
125+
elif axis != -1:
126126
pass
127127
else:
128128
output_boundarie = input_boundarie
129-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
129+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value, 0).get_pyobj()
130130
return call_origin(numpy.fft.fft, x1, n, axis, norm)
131131

132132

@@ -246,7 +246,7 @@ def fftshift(x1, axes=None):
246246
if x1_desc.size < 1:
247247
pass # let fallback to handle exception
248248
else:
249-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
249+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value, 0).get_pyobj()
250250

251251
return call_origin(numpy.fft.fftshift, x1, axes)
252252

@@ -288,7 +288,7 @@ def hfft(x1, n=None, axis=-1, norm=None):
288288
else:
289289
output_boundarie = input_boundarie
290290

291-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
291+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value, 0).get_pyobj()
292292

293293
return call_origin(numpy.fft.hfft, x1, n, axis, norm)
294294

@@ -327,10 +327,12 @@ def ifft(x1, n=None, axis=-1, norm=None):
327327
pass # let fallback to handle exception
328328
elif norm is not None:
329329
pass
330+
elif n is not None:
331+
pass
330332
else:
331333
output_boundarie = input_boundarie
332334

333-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()
335+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value, 0).get_pyobj()
334336

335337
return call_origin(numpy.fft.ifft, x1, n, axis, norm)
336338

@@ -388,7 +390,7 @@ def ifftshift(x1, axes=None):
388390
if x1_desc.size < 1:
389391
pass # let fallback to handle exception
390392
else:
391-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
393+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value, 0).get_pyobj()
392394

393395
return call_origin(numpy.fft.ifftshift, x1, axes)
394396

@@ -476,10 +478,12 @@ def ihfft(x1, n=None, axis=-1, norm=None):
476478
pass # let fallback to handle exception
477479
elif norm is not None:
478480
pass
481+
elif n is not None:
482+
pass
479483
else:
480484
output_boundarie = input_boundarie
481485

482-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
486+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value, 0).get_pyobj()
483487

484488
return call_origin(numpy.fft.ihfft, x1, n, axis, norm)
485489

@@ -518,10 +522,12 @@ def irfft(x1, n=None, axis=-1, norm=None):
518522
pass # let fallback to handle exception
519523
elif norm is not None:
520524
pass
525+
elif n is not None:
526+
pass
521527
else:
522528
output_boundarie = 2 * (input_boundarie - 1)
523529

524-
result = dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()
530+
result = dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value, 1).get_pyobj()
525531
# TODO tmp = utils.create_output_array(result_shape, result_c_type, out)
526532
# tmp = dparray(result.shape, dtype=dpnp.float64)
527533
# for it in range(tmp.size):
@@ -638,16 +644,17 @@ def rfft(x1, n=None, axis=-1, norm=None):
638644
pass # let fallback to handle exception
639645
elif input_boundarie < 1:
640646
pass # let fallback to handle exception
647+
elif axis != -1:
648+
pass
641649
elif norm is not None:
642650
pass
643-
elif x1_desc.ndim > 1:
651+
elif n is not None:
644652
pass
645-
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
653+
elif x1_desc.dtype in (numpy.complex128, numpy.complex64):
646654
pass
647655
else:
648656
output_boundarie = input_boundarie // 2 + 1 # rfft specific requirement
649-
650-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
657+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value, 1).get_pyobj()
651658

652659
return call_origin(numpy.fft.rfft, x1, n, axis, norm)
653660

0 commit comments

Comments
 (0)