Skip to content

add ifft support #1142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 54 additions & 18 deletions dpnp/backend/kernels/dpnp_krnl_fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
const size_t input_size,
const size_t result_size,
_Descriptor_type& desc,
const size_t norm)
size_t inverse,
double backward_scale,
double forward_scale)
{
if (!shape_size)
{
Expand All @@ -199,9 +201,6 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,

const size_t shift = input_shape[shape_size - 1];

double forward_scale = 1.0;
double backward_scale = 1.0 / shift;

desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
// enum value from math library C interface
Expand All @@ -213,7 +212,11 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
fft_events.reserve(n_iter);

for (size_t i = 0; i < n_iter; ++i) {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
if (inverse) {
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * shift, result + i * shift));
} else {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
}
}

sycl::event::wait(fft_events);
Expand All @@ -234,7 +237,9 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
const size_t input_size,
const size_t result_size,
_Descriptor_type& desc,
const size_t norm,
size_t inverse,
double backward_scale,
double forward_scale,
const size_t real)
{
if (!shape_size)
Expand All @@ -255,19 +260,21 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
const size_t input_shift = input_shape[shape_size - 1];
const size_t result_shift = result_shape[shape_size - 1];;

double forward_scale = 1.0;
double backward_scale = 1.0 / input_shift;

desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);

desc.commit(queue);

std::vector<sycl::event> fft_events;
fft_events.reserve(n_iter);

for (size_t i = 0; i < n_iter; ++i) {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
if (inverse) {
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
} else {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
}
}

sycl::event::wait(fft_events);
Expand Down Expand Up @@ -330,6 +337,21 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,

size_t dim = input_shape[shape_size - 1];

double backward_scale = 1;
double forward_scale = 1;

if (norm == 0) { // norm = "backward"
backward_scale = 1. / dim;
} else if (norm == 1) { // norm = "forward"
forward_scale = 1. / dim;
} else { // norm = "ortho"
if (inverse) {
backward_scale = 1. / sqrt(dim);
} else {
forward_scale = 1. / sqrt(dim);
}
}

if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
std::is_same<_DataType_output, std::complex<double>>::value)
{
Expand All @@ -338,15 +360,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
{
desc_dp_cmplx_t desc(dim);
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
}
/* complex-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_cmplx_t desc(dim);
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
}
/* real-to-complex, double precision */
else if constexpr (std::is_same<_DataType_input, double>::value &&
Expand All @@ -355,15 +377,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
desc_dp_real_t desc(dim);

dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
}
/* real-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, float>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_real_t desc(dim); // try: 2 * result_size
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
}
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
Expand All @@ -380,7 +402,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,

desc_dp_real_t desc(dim);
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0);
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);

dpnp_memory_free_c(q_ref, array1_copy);
dpnp_memory_free_c(q_ref, copy_strides);
Expand Down Expand Up @@ -484,6 +506,20 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,

size_t dim = input_shape[shape_size - 1];

double backward_scale = 1;
double forward_scale = 1;
if (norm == 0) { // norm = "backward"
backward_scale = 1. / dim;
} else if (norm == 1) { // norm = "forward"
forward_scale = 1. / dim;
} else { // norm = "ortho"
if (inverse) {
backward_scale = 1. / sqrt(dim);
} else {
forward_scale = 1. / sqrt(dim);
}
}

if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
std::is_same<_DataType_output, std::complex<double>>::value)
{
Expand All @@ -493,15 +529,15 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
desc_dp_real_t desc(dim);

dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1l);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
}
/* real-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, float>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_real_t desc(dim); // try: 2 * result_size
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
}
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
Expand All @@ -518,7 +554,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,

desc_dp_real_t desc(dim);
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1);
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);

dpnp_memory_free_c(q_ref, array1_copy);
dpnp_memory_free_c(q_ref, copy_strides);
Expand Down
9 changes: 3 additions & 6 deletions dpnp/fft/dpnp_iface_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ def fft(x1, n=None, axis=-1, norm=None):
pass # let fallback to handle exception
elif input_boundarie < 1:
pass # let fallback to handle exception
elif norm is not None:
pass
elif n is not None:
pass
elif axis != -1:
Expand Down Expand Up @@ -308,7 +306,7 @@ def ifft(x1, n=None, axis=-1, norm=None):
"""

x1_desc = dpnp.get_dpnp_descriptor(x1)
if x1_desc and 0:
if x1_desc:
norm_ = get_validated_norm(norm)

if axis is None:
Expand All @@ -325,13 +323,12 @@ def ifft(x1, n=None, axis=-1, norm=None):
pass # let fallback to handle exception
elif input_boundarie < 1:
pass # let fallback to handle exception
elif norm is not None:
pass
elif n is not None:
pass
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
pass
else:
output_boundarie = input_boundarie

return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()

return call_origin(numpy.fft.ifft, x1, n, axis, norm)
Expand Down
28 changes: 22 additions & 6 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,45 @@


@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
def test_fft(type):
@pytest.mark.parametrize("norm", [None, 'forward', 'ortho'])
def test_fft(type, norm):
# 1 dim array
data = numpy.arange(100, dtype=numpy.dtype(type))
# TODO:
# doesn't work correct with `complex64` (not supported)
# dpnp_data = dpnp.arange(100, dtype=dpnp.dtype(type))
dpnp_data = dpnp.array(data)

np_res = numpy.fft.fft(data)
dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data))
np_res = numpy.fft.fft(data, norm=norm)
dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data, norm=norm))

numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
assert dpnp_res.dtype == np_res.dtype


@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize("shape", [(8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)])
def test_fft_ndim(type, shape):
@pytest.mark.parametrize("norm", [None, 'forward', 'ortho'])
def test_fft_ndim(type, shape, norm):
np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape)
dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape)

np_res = numpy.fft.fft(np_data)
dpnp_res = dpnp.fft.fft(dpnp_data)
np_res = numpy.fft.fft(np_data, norm=norm)
dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm)

numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
assert dpnp_res.dtype == np_res.dtype


@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize("shape", [(64,), (8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)])
@pytest.mark.parametrize("norm", [None, 'forward', 'ortho'])
def test_fft_ifft(type, shape, norm):
np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape)
dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape)

np_res = numpy.fft.ifft(np_data, norm=norm)
dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm)

numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
assert dpnp_res.dtype == np_res.dtype
Expand Down