Skip to content

support real input for dpnp.fft.ifft #1150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 51 additions & 46 deletions dpnp/backend/kernels/dpnp_krnl_fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
const size_t result_size,
_Descriptor_type& desc,
size_t inverse,
double backward_scale,
double forward_scale)
const size_t norm)
{
if (!shape_size)
{
Expand All @@ -201,6 +200,21 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,

const size_t shift = input_shape[shape_size - 1];

double backward_scale = 1.;
double forward_scale = 1.;

if (norm == 0) { // norm = "backward"
backward_scale = 1. / shift;
} else if (norm == 1) { // norm = "forward"
forward_scale = 1. / shift;
} else { // norm = "ortho"
if (inverse) {
backward_scale = 1. / sqrt(shift);
} else {
forward_scale = 1. / sqrt(shift);
}
}

desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
// enum value from math library C interface
Expand Down Expand Up @@ -238,8 +252,7 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
const size_t result_size,
_Descriptor_type& desc,
size_t inverse,
double backward_scale,
double forward_scale,
const size_t norm,
const size_t real)
{
if (!shape_size)
Expand All @@ -258,7 +271,26 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());

const size_t input_shift = input_shape[shape_size - 1];
const size_t result_shift = result_shape[shape_size - 1];;
const size_t result_shift = result_shape[shape_size - 1];

double backward_scale = 1.;
double forward_scale = 1.;

if (norm == 0) { // norm = "backward"
if (inverse) {
forward_scale = 1. / result_shift;
} else {
backward_scale = 1. / result_shift;
}
} else if (norm == 1) { // norm = "forward"
if (inverse) {
backward_scale = 1. / result_shift;
} else {
forward_scale = 1. / result_shift;
}
} else { // norm = "ortho"
forward_scale = 1. / sqrt(result_shift);
}

desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
Expand All @@ -270,11 +302,7 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
fft_events.reserve(n_iter);

for (size_t i = 0; i < n_iter; ++i) {
if (inverse) {
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
} else {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
}
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
}

sycl::event::wait(fft_events);
Expand Down Expand Up @@ -307,6 +335,11 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
event = queue.submit(kernel_func);
event.wait();

if (inverse) {
event = oneapi::mkl::vm::conj(queue, result_size, reinterpret_cast<std::complex<_DataType_output>*>(result), reinterpret_cast<std::complex<_DataType_output>*>(result));
event.wait();
}

return;
}

Expand Down Expand Up @@ -337,21 +370,6 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,

size_t dim = input_shape[shape_size - 1];

double backward_scale = 1;
double forward_scale = 1;

if (norm == 0) { // norm = "backward"
backward_scale = 1. / dim;
} else if (norm == 1) { // norm = "forward"
forward_scale = 1. / dim;
} else { // norm = "ortho"
if (inverse) {
backward_scale = 1. / sqrt(dim);
} else {
forward_scale = 1. / sqrt(dim);
}
}

if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
std::is_same<_DataType_output, std::complex<double>>::value)
{
Expand All @@ -360,15 +378,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
{
desc_dp_cmplx_t desc(dim);
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
}
/* complex-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_cmplx_t desc(dim);
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
}
/* real-to-complex, double precision */
else if constexpr (std::is_same<_DataType_input, double>::value &&
Expand All @@ -377,15 +395,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
desc_dp_real_t desc(dim);

dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
}
/* real-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, float>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_real_t desc(dim); // try: 2 * result_size
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
}
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
Expand All @@ -402,7 +420,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,

desc_dp_real_t desc(dim);
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0);
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);

dpnp_memory_free_c(q_ref, array1_copy);
dpnp_memory_free_c(q_ref, copy_strides);
Expand Down Expand Up @@ -506,19 +524,6 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,

size_t dim = input_shape[shape_size - 1];

double backward_scale = 1;
double forward_scale = 1;
if (norm == 0) { // norm = "backward"
backward_scale = 1. / dim;
} else if (norm == 1) { // norm = "forward"
forward_scale = 1. / dim;
} else { // norm = "ortho"
if (inverse) {
backward_scale = 1. / sqrt(dim);
} else {
forward_scale = 1. / sqrt(dim);
}
}

if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
std::is_same<_DataType_output, std::complex<double>>::value)
Expand All @@ -529,15 +534,15 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
desc_dp_real_t desc(dim);

dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
}
/* real-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, float>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_real_t desc(dim); // try: 2 * result_size
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
}
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
Expand All @@ -554,7 +559,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,

desc_dp_real_t desc(dim);
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1);
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);

dpnp_memory_free_c(q_ref, array1_copy);
dpnp_memory_free_c(q_ref, copy_strides);
Expand Down
8 changes: 3 additions & 5 deletions dpnp/fft/dpnp_iface_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def ifft(x1, n=None, axis=-1, norm=None):

Limitations
-----------
Parameter ``norm`` is unsupported.
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64`` and
``dpnp.complex128`` datatypes only.
Parameter ``axis`` is supported with its default value.
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64``,
``dpnp.complex64``and ``dpnp.complex128`` datatypes only.

For full documentation refer to :obj:`numpy.fft.ifft`.

Expand All @@ -325,8 +325,6 @@ def ifft(x1, n=None, axis=-1, norm=None):
pass # let fallback to handle exception
elif n is not None:
pass
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
pass
else:
output_boundarie = input_boundarie
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()
Expand Down