Skip to content

Commit 0234590

Browse files
authored
add ifft support (#1142)
* add ifft support
1 parent 6874366 commit 0234590

File tree

3 files changed

+79
-30
lines changed

3 files changed

+79
-30
lines changed

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
180180
const size_t input_size,
181181
const size_t result_size,
182182
_Descriptor_type& desc,
183-
const size_t norm)
183+
size_t inverse,
184+
double backward_scale,
185+
double forward_scale)
184186
{
185187
if (!shape_size)
186188
{
@@ -199,9 +201,6 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
199201

200202
const size_t shift = input_shape[shape_size - 1];
201203

202-
double forward_scale = 1.0;
203-
double backward_scale = 1.0 / shift;
204-
205204
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
206205
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
207206
// enum value from math library C interface
@@ -213,7 +212,11 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
213212
fft_events.reserve(n_iter);
214213

215214
for (size_t i = 0; i < n_iter; ++i) {
216-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
215+
if (inverse) {
216+
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * shift, result + i * shift));
217+
} else {
218+
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
219+
}
217220
}
218221

219222
sycl::event::wait(fft_events);
@@ -234,7 +237,9 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
234237
const size_t input_size,
235238
const size_t result_size,
236239
_Descriptor_type& desc,
237-
const size_t norm,
240+
size_t inverse,
241+
double backward_scale,
242+
double forward_scale,
238243
const size_t real)
239244
{
240245
if (!shape_size)
@@ -255,19 +260,21 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
255260
const size_t input_shift = input_shape[shape_size - 1];
256261
const size_t result_shift = result_shape[shape_size - 1];;
257262

258-
double forward_scale = 1.0;
259-
double backward_scale = 1.0 / input_shift;
260-
261263
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
262264
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
265+
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
263266

264267
desc.commit(queue);
265268

266269
std::vector<sycl::event> fft_events;
267270
fft_events.reserve(n_iter);
268271

269272
for (size_t i = 0; i < n_iter; ++i) {
270-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
273+
if (inverse) {
274+
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
275+
} else {
276+
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
277+
}
271278
}
272279

273280
sycl::event::wait(fft_events);
@@ -330,6 +337,21 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
330337

331338
size_t dim = input_shape[shape_size - 1];
332339

340+
double backward_scale = 1;
341+
double forward_scale = 1;
342+
343+
if (norm == 0) { // norm = "backward"
344+
backward_scale = 1. / dim;
345+
} else if (norm == 1) { // norm = "forward"
346+
forward_scale = 1. / dim;
347+
} else { // norm = "ortho"
348+
if (inverse) {
349+
backward_scale = 1. / sqrt(dim);
350+
} else {
351+
forward_scale = 1. / sqrt(dim);
352+
}
353+
}
354+
333355
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
334356
std::is_same<_DataType_output, std::complex<double>>::value)
335357
{
@@ -338,15 +360,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
338360
{
339361
desc_dp_cmplx_t desc(dim);
340362
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
341-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
363+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
342364
}
343365
/* complex-to-complex, single precision */
344366
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
345367
std::is_same<_DataType_output, std::complex<float>>::value)
346368
{
347369
desc_sp_cmplx_t desc(dim);
348370
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
349-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
371+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
350372
}
351373
/* real-to-complex, double precision */
352374
else if constexpr (std::is_same<_DataType_input, double>::value &&
@@ -355,15 +377,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
355377
desc_dp_real_t desc(dim);
356378

357379
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
358-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0);
380+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
359381
}
360382
/* real-to-complex, single precision */
361383
else if constexpr (std::is_same<_DataType_input, float>::value &&
362384
std::is_same<_DataType_output, std::complex<float>>::value)
363385
{
364386
desc_sp_real_t desc(dim); // try: 2 * result_size
365387
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
366-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0);
388+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
367389
}
368390
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
369391
std::is_same<_DataType_input, int64_t>::value)
@@ -380,7 +402,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
380402

381403
desc_dp_real_t desc(dim);
382404
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
383-
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0);
405+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
384406

385407
dpnp_memory_free_c(q_ref, array1_copy);
386408
dpnp_memory_free_c(q_ref, copy_strides);
@@ -484,6 +506,20 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
484506

485507
size_t dim = input_shape[shape_size - 1];
486508

509+
double backward_scale = 1;
510+
double forward_scale = 1;
511+
if (norm == 0) { // norm = "backward"
512+
backward_scale = 1. / dim;
513+
} else if (norm == 1) { // norm = "forward"
514+
forward_scale = 1. / dim;
515+
} else { // norm = "ortho"
516+
if (inverse) {
517+
backward_scale = 1. / sqrt(dim);
518+
} else {
519+
forward_scale = 1. / sqrt(dim);
520+
}
521+
}
522+
487523
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
488524
std::is_same<_DataType_output, std::complex<double>>::value)
489525
{
@@ -493,15 +529,15 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
493529
desc_dp_real_t desc(dim);
494530

495531
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
496-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1l);
532+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
497533
}
498534
/* real-to-complex, single precision */
499535
else if constexpr (std::is_same<_DataType_input, float>::value &&
500536
std::is_same<_DataType_output, std::complex<float>>::value)
501537
{
502538
desc_sp_real_t desc(dim); // try: 2 * result_size
503539
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
504-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1);
540+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
505541
}
506542
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
507543
std::is_same<_DataType_input, int64_t>::value)
@@ -518,7 +554,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
518554

519555
desc_dp_real_t desc(dim);
520556
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
521-
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1);
557+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
522558

523559
dpnp_memory_free_c(q_ref, array1_copy);
524560
dpnp_memory_free_c(q_ref, copy_strides);

dpnp/fft/dpnp_iface_fft.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ def fft(x1, n=None, axis=-1, norm=None):
118118
pass # let fallback to handle exception
119119
elif input_boundarie < 1:
120120
pass # let fallback to handle exception
121-
elif norm is not None:
122-
pass
123121
elif n is not None:
124122
pass
125123
elif axis != -1:
@@ -308,7 +306,7 @@ def ifft(x1, n=None, axis=-1, norm=None):
308306
"""
309307

310308
x1_desc = dpnp.get_dpnp_descriptor(x1)
311-
if x1_desc and 0:
309+
if x1_desc:
312310
norm_ = get_validated_norm(norm)
313311

314312
if axis is None:
@@ -325,13 +323,12 @@ def ifft(x1, n=None, axis=-1, norm=None):
325323
pass # let fallback to handle exception
326324
elif input_boundarie < 1:
327325
pass # let fallback to handle exception
328-
elif norm is not None:
329-
pass
330326
elif n is not None:
331327
pass
328+
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
329+
pass
332330
else:
333331
output_boundarie = input_boundarie
334-
335332
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()
336333

337334
return call_origin(numpy.fft.ifft, x1, n, axis, norm)

tests/test_fft.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,45 @@
66

77

88
@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
9-
def test_fft(type):
9+
@pytest.mark.parametrize("norm", [None, 'forward', 'ortho'])
10+
def test_fft(type, norm):
1011
# 1 dim array
1112
data = numpy.arange(100, dtype=numpy.dtype(type))
1213
# TODO:
1314
# doesn't work correct with `complex64` (not supported)
1415
# dpnp_data = dpnp.arange(100, dtype=dpnp.dtype(type))
1516
dpnp_data = dpnp.array(data)
1617

17-
np_res = numpy.fft.fft(data)
18-
dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data))
18+
np_res = numpy.fft.fft(data, norm=norm)
19+
dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data, norm=norm))
1920

2021
numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
2122
assert dpnp_res.dtype == np_res.dtype
2223

2324

2425
@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
2526
@pytest.mark.parametrize("shape", [(8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)])
26-
def test_fft_ndim(type, shape):
27+
@pytest.mark.parametrize("norm", [None, 'forward', 'ortho'])
28+
def test_fft_ndim(type, shape, norm):
2729
np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape)
2830
dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape)
2931

30-
np_res = numpy.fft.fft(np_data)
31-
dpnp_res = dpnp.fft.fft(dpnp_data)
32+
np_res = numpy.fft.fft(np_data, norm=norm)
33+
dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm)
34+
35+
numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
36+
assert dpnp_res.dtype == np_res.dtype
37+
38+
39+
@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
40+
@pytest.mark.parametrize("shape", [(64,), (8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)])
41+
@pytest.mark.parametrize("norm", [None, 'forward', 'ortho'])
42+
def test_fft_ifft(type, shape, norm):
43+
np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape)
44+
dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape)
45+
46+
np_res = numpy.fft.ifft(np_data, norm=norm)
47+
dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm)
3248

3349
numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
3450
assert dpnp_res.dtype == np_res.dtype

0 commit comments

Comments
 (0)