Skip to content

Commit 395f051

Browse files
support ifft with real input (#1150)
Co-authored-by: Alexander-Makaryev <[email protected]>
1 parent f8b0ce8 commit 395f051

File tree

2 files changed

+54
-51
lines changed

2 files changed

+54
-51
lines changed

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
181181
const size_t result_size,
182182
_Descriptor_type& desc,
183183
size_t inverse,
184-
double backward_scale,
185-
double forward_scale)
184+
const size_t norm)
186185
{
187186
if (!shape_size)
188187
{
@@ -201,6 +200,21 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
201200

202201
const size_t shift = input_shape[shape_size - 1];
203202

203+
double backward_scale = 1.;
204+
double forward_scale = 1.;
205+
206+
if (norm == 0) { // norm = "backward"
207+
backward_scale = 1. / shift;
208+
} else if (norm == 1) { // norm = "forward"
209+
forward_scale = 1. / shift;
210+
} else { // norm = "ortho"
211+
if (inverse) {
212+
backward_scale = 1. / sqrt(shift);
213+
} else {
214+
forward_scale = 1. / sqrt(shift);
215+
}
216+
}
217+
204218
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
205219
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
206220
// enum value from math library C interface
@@ -238,8 +252,7 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
238252
const size_t result_size,
239253
_Descriptor_type& desc,
240254
size_t inverse,
241-
double backward_scale,
242-
double forward_scale,
255+
const size_t norm,
243256
const size_t real)
244257
{
245258
if (!shape_size)
@@ -258,7 +271,26 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
258271
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
259272

260273
const size_t input_shift = input_shape[shape_size - 1];
261-
const size_t result_shift = result_shape[shape_size - 1];;
274+
const size_t result_shift = result_shape[shape_size - 1];
275+
276+
double backward_scale = 1.;
277+
double forward_scale = 1.;
278+
279+
if (norm == 0) { // norm = "backward"
280+
if (inverse) {
281+
forward_scale = 1. / result_shift;
282+
} else {
283+
backward_scale = 1. / result_shift;
284+
}
285+
} else if (norm == 1) { // norm = "forward"
286+
if (inverse) {
287+
backward_scale = 1. / result_shift;
288+
} else {
289+
forward_scale = 1. / result_shift;
290+
}
291+
} else { // norm = "ortho"
292+
forward_scale = 1. / sqrt(result_shift);
293+
}
262294

263295
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
264296
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
@@ -270,11 +302,7 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
270302
fft_events.reserve(n_iter);
271303

272304
for (size_t i = 0; i < n_iter; ++i) {
273-
if (inverse) {
274-
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
275-
} else {
276-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
277-
}
305+
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
278306
}
279307

280308
sycl::event::wait(fft_events);
@@ -307,6 +335,11 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
307335
event = queue.submit(kernel_func);
308336
event.wait();
309337

338+
if (inverse) {
339+
event = oneapi::mkl::vm::conj(queue, result_size, reinterpret_cast<std::complex<_DataType_output>*>(result), reinterpret_cast<std::complex<_DataType_output>*>(result));
340+
event.wait();
341+
}
342+
310343
return;
311344
}
312345

@@ -337,21 +370,6 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
337370

338371
size_t dim = input_shape[shape_size - 1];
339372

340-
double backward_scale = 1;
341-
double forward_scale = 1;
342-
343-
if (norm == 0) { // norm = "backward"
344-
backward_scale = 1. / dim;
345-
} else if (norm == 1) { // norm = "forward"
346-
forward_scale = 1. / dim;
347-
} else { // norm = "ortho"
348-
if (inverse) {
349-
backward_scale = 1. / sqrt(dim);
350-
} else {
351-
forward_scale = 1. / sqrt(dim);
352-
}
353-
}
354-
355373
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
356374
std::is_same<_DataType_output, std::complex<double>>::value)
357375
{
@@ -360,15 +378,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
360378
{
361379
desc_dp_cmplx_t desc(dim);
362380
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
363-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
381+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
364382
}
365383
/* complex-to-complex, single precision */
366384
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
367385
std::is_same<_DataType_output, std::complex<float>>::value)
368386
{
369387
desc_sp_cmplx_t desc(dim);
370388
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
371-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
389+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
372390
}
373391
/* real-to-complex, double precision */
374392
else if constexpr (std::is_same<_DataType_input, double>::value &&
@@ -377,15 +395,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
377395
desc_dp_real_t desc(dim);
378396

379397
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
380-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
398+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
381399
}
382400
/* real-to-complex, single precision */
383401
else if constexpr (std::is_same<_DataType_input, float>::value &&
384402
std::is_same<_DataType_output, std::complex<float>>::value)
385403
{
386404
desc_sp_real_t desc(dim); // try: 2 * result_size
387405
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
388-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
406+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
389407
}
390408
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
391409
std::is_same<_DataType_input, int64_t>::value)
@@ -402,7 +420,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
402420

403421
desc_dp_real_t desc(dim);
404422
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
405-
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
423+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
406424

407425
dpnp_memory_free_c(q_ref, array1_copy);
408426
dpnp_memory_free_c(q_ref, copy_strides);
@@ -506,19 +524,6 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
506524

507525
size_t dim = input_shape[shape_size - 1];
508526

509-
double backward_scale = 1;
510-
double forward_scale = 1;
511-
if (norm == 0) { // norm = "backward"
512-
backward_scale = 1. / dim;
513-
} else if (norm == 1) { // norm = "forward"
514-
forward_scale = 1. / dim;
515-
} else { // norm = "ortho"
516-
if (inverse) {
517-
backward_scale = 1. / sqrt(dim);
518-
} else {
519-
forward_scale = 1. / sqrt(dim);
520-
}
521-
}
522527

523528
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
524529
std::is_same<_DataType_output, std::complex<double>>::value)
@@ -529,15 +534,15 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
529534
desc_dp_real_t desc(dim);
530535

531536
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
532-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
537+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
533538
}
534539
/* real-to-complex, single precision */
535540
else if constexpr (std::is_same<_DataType_input, float>::value &&
536541
std::is_same<_DataType_output, std::complex<float>>::value)
537542
{
538543
desc_sp_real_t desc(dim); // try: 2 * result_size
539544
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
540-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
545+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
541546
}
542547
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
543548
std::is_same<_DataType_input, int64_t>::value)
@@ -554,7 +559,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
554559

555560
desc_dp_real_t desc(dim);
556561
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
557-
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
562+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
558563

559564
dpnp_memory_free_c(q_ref, array1_copy);
560565
dpnp_memory_free_c(q_ref, copy_strides);

dpnp/fft/dpnp_iface_fft.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ def ifft(x1, n=None, axis=-1, norm=None):
297297
298298
Limitations
299299
-----------
300-
Parameter ``norm`` is unsupported.
301-
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64`` and
302-
``dpnp.complex128`` datatypes only.
300+
Parameter ``axis`` is supported with its default value.
301+
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64``,
302+
``dpnp.complex64``and ``dpnp.complex128`` datatypes only.
303303
304304
For full documentation refer to :obj:`numpy.fft.ifft`.
305305
@@ -325,8 +325,6 @@ def ifft(x1, n=None, axis=-1, norm=None):
325325
pass # let fallback to handle exception
326326
elif n is not None:
327327
pass
328-
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
329-
pass
330328
else:
331329
output_boundarie = input_boundarie
332330
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()

0 commit comments

Comments
 (0)