Skip to content

Commit 00f4830

Browse files
committed
implement in-place fft
1 parent 89f2d42 commit 00f4830

File tree

9 files changed

+380
-77
lines changed

9 files changed

+380
-77
lines changed

dpnp/backend/extensions/fft/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
set(python_module_name _fft_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/fft_py.cpp
30-
${CMAKE_CURRENT_SOURCE_DIR}/c2c.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/c2c_in.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/c2c_out.cpp
3132
)
3233

3334
pybind11_add_module(${python_module_name} MODULE ${_module_src})

dpnp/backend/extensions/fft/c2c.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,20 @@ class ComplexDescriptorWrapper
222222
std::unique_ptr<sycl::queue> queue_ptr_;
223223
};
224224

225+
// forward declaration
225226
template <mkl_dft::precision prec>
226227
std::pair<sycl::event, sycl::event>
227-
compute_fft(ComplexDescriptorWrapper<prec> &descr,
228-
const dpctl::tensor::usm_ndarray &in,
229-
const dpctl::tensor::usm_ndarray &out,
230-
const bool is_forward,
231-
const std::vector<sycl::event> &depends);
228+
compute_fft_out_of_place(ComplexDescriptorWrapper<prec> &descr,
229+
const dpctl::tensor::usm_ndarray &in,
230+
const dpctl::tensor::usm_ndarray &out,
231+
const bool is_forward,
232+
const std::vector<sycl::event> &depends);
233+
234+
template <mkl_dft::precision prec>
235+
std::pair<sycl::event, sycl::event>
236+
compute_fft_in_place(ComplexDescriptorWrapper<prec> &descr,
237+
const dpctl::tensor::usm_ndarray &in_out,
238+
const bool is_forward,
239+
const std::vector<sycl::event> &depends);
232240

233241
} // namespace dpnp::extensions::fft
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <oneapi/mkl.hpp>
27+
#include <sycl/sycl.hpp>
28+
29+
#include <dpctl4pybind11.hpp>
30+
31+
#include "c2c.hpp"
32+
#include "fft_utils.hpp"
33+
// dpctl tensor headers
34+
#include "utils/output_validation.hpp"
35+
36+
namespace dpnp::extensions::fft
37+
{
38+
namespace mkl_dft = oneapi::mkl::dft;
39+
namespace py = pybind11;
40+
41+
// in-place FFT computation
42+
template <mkl_dft::precision prec>
43+
std::pair<sycl::event, sycl::event>
44+
compute_fft_in_place(ComplexDescriptorWrapper<prec> &descr,
45+
const dpctl::tensor::usm_ndarray &in_out,
46+
const bool is_forward,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
bool committed = descr.is_committed();
50+
if (!committed) {
51+
throw py::value_error("Descriptor is not committed");
52+
}
53+
54+
const bool in_place = descr.get_in_place();
55+
if (!in_place) {
56+
throw py::value_error("Descriptor is defined for out-of-place FFT "
57+
"while this function is set "
58+
"to compute in-place FFT.");
59+
}
60+
61+
sycl::queue exec_q = descr.get_queue();
62+
if (!dpctl::utils::queues_are_compatible(exec_q, {in_out.get_queue()})) {
63+
throw py::value_error(
64+
"USM allocations are not compatible with the execution queue.");
65+
}
66+
67+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(in_out);
68+
69+
using ScaleT = typename ScaleType<prec>::value_type;
70+
std::complex<ScaleT> *in_out_ptr = in_out.get_data<std::complex<ScaleT>>();
71+
72+
sycl::event fft_event = {};
73+
std::stringstream error_msg;
74+
bool is_exception_caught = false;
75+
76+
try {
77+
if (is_forward) {
78+
fft_event = oneapi::mkl::dft::compute_forward(
79+
descr.get_descriptor(), in_out_ptr, depends);
80+
}
81+
else {
82+
fft_event = oneapi::mkl::dft::compute_backward(
83+
descr.get_descriptor(), in_out_ptr, depends);
84+
}
85+
} catch (oneapi::mkl::exception const &e) {
86+
error_msg
87+
<< "Unexpected MKL exception caught during FFT() call:\nreason: "
88+
<< e.what();
89+
is_exception_caught = true;
90+
} catch (sycl::exception const &e) {
91+
error_msg << "Unexpected SYCL exception caught during FFT() call:\n"
92+
<< e.what();
93+
is_exception_caught = true;
94+
}
95+
if (is_exception_caught) {
96+
throw std::runtime_error(error_msg.str());
97+
}
98+
99+
sycl::event args_ev =
100+
dpctl::utils::keep_args_alive(exec_q, {in_out}, {fft_event});
101+
102+
return std::make_pair(fft_event, args_ev);
103+
}
104+
105+
// Explicit instantiations
106+
template std::pair<sycl::event, sycl::event> compute_fft_in_place(
107+
ComplexDescriptorWrapper<mkl_dft::precision::SINGLE> &descr,
108+
const dpctl::tensor::usm_ndarray &in_out,
109+
const bool is_forward,
110+
const std::vector<sycl::event> &depends);
111+
112+
template std::pair<sycl::event, sycl::event> compute_fft_in_place(
113+
ComplexDescriptorWrapper<mkl_dft::precision::DOUBLE> &descr,
114+
const dpctl::tensor::usm_ndarray &in_out,
115+
const bool is_forward,
116+
const std::vector<sycl::event> &depends);
117+
118+
} // namespace dpnp::extensions::fft

dpnp/backend/extensions/fft/c2c.cpp renamed to dpnp/backend/extensions/fft/c2c_out.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ namespace dpnp::extensions::fft
3939
namespace mkl_dft = oneapi::mkl::dft;
4040
namespace py = pybind11;
4141

42+
// out-of-place FFT computation
4243
template <mkl_dft::precision prec>
4344
std::pair<sycl::event, sycl::event>
44-
compute_fft(ComplexDescriptorWrapper<prec> &descr,
45-
const dpctl::tensor::usm_ndarray &in,
46-
const dpctl::tensor::usm_ndarray &out,
47-
const bool is_forward,
48-
const std::vector<sycl::event> &depends)
45+
compute_fft_out_of_place(ComplexDescriptorWrapper<prec> &descr,
46+
const dpctl::tensor::usm_ndarray &in,
47+
const dpctl::tensor::usm_ndarray &out,
48+
const bool is_forward,
49+
const std::vector<sycl::event> &depends)
4950
{
5051
bool committed = descr.is_committed();
5152
if (!committed) {
@@ -129,18 +130,18 @@ std::pair<sycl::event, sycl::event>
129130
}
130131

131132
// Explicit instantiations
132-
template std::pair<sycl::event, sycl::event>
133-
compute_fft(ComplexDescriptorWrapper<mkl_dft::precision::SINGLE> &descr,
134-
const dpctl::tensor::usm_ndarray &in,
135-
const dpctl::tensor::usm_ndarray &out,
136-
const bool is_forward,
137-
const std::vector<sycl::event> &depends);
138-
139-
template std::pair<sycl::event, sycl::event>
140-
compute_fft(ComplexDescriptorWrapper<mkl_dft::precision::DOUBLE> &descr,
141-
const dpctl::tensor::usm_ndarray &in,
142-
const dpctl::tensor::usm_ndarray &out,
143-
const bool is_forward,
144-
const std::vector<sycl::event> &depends);
133+
template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
134+
ComplexDescriptorWrapper<mkl_dft::precision::SINGLE> &descr,
135+
const dpctl::tensor::usm_ndarray &in,
136+
const dpctl::tensor::usm_ndarray &out,
137+
const bool is_forward,
138+
const std::vector<sycl::event> &depends);
139+
140+
template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
141+
ComplexDescriptorWrapper<mkl_dft::precision::DOUBLE> &descr,
142+
const dpctl::tensor::usm_ndarray &in,
143+
const dpctl::tensor::usm_ndarray &out,
144+
const bool is_forward,
145+
const std::vector<sycl::event> &depends);
145146

146147
} // namespace dpnp::extensions::fft

dpnp/backend/extensions/fft/fft_py.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,33 @@ PYBIND11_MODULE(_fft_impl, m)
6767
{
6868
constexpr mkl_dft::precision single_prec = mkl_dft::precision::SINGLE;
6969
register_complex_descriptor<single_prec>(m, "Complex64Descriptor");
70-
m.def("compute_fft", &fft_ns::compute_fft<single_prec>,
71-
"Compute forward/backward fft using OneMKL DFT library for complex "
72-
"float data types.",
70+
71+
m.def("compute_fft_out_of_place",
72+
&fft_ns::compute_fft_out_of_place<single_prec>,
73+
"Compute out-of-place fft using OneMKL DFT library for complex64 "
74+
"data types.",
7375
py::arg("descriptor"), py::arg("input"), py::arg("output"),
7476
py::arg("is_forward"), py::arg("depends") = py::list());
7577

78+
m.def("compute_fft_in_place", &fft_ns::compute_fft_in_place<single_prec>,
79+
"Compute in-place fft using OneMKL DFT library for complex64 data "
80+
"types.",
81+
py::arg("descriptor"), py::arg("input-output"), py::arg("is_forward"),
82+
py::arg("depends") = py::list());
83+
7684
constexpr mkl_dft::precision double_prec = mkl_dft::precision::DOUBLE;
7785
register_complex_descriptor<double_prec>(m, "Complex128Descriptor");
78-
m.def("compute_fft", &fft_ns::compute_fft<double_prec>,
79-
"Compute forward/backward fft using OneMKL DFT library for complex "
80-
"double data types.",
86+
87+
m.def("compute_fft_out_of_place",
88+
&fft_ns::compute_fft_out_of_place<double_prec>,
89+
"Compute out-of-place fft using OneMKL DFT library for complex128 "
90+
"data types.",
8191
py::arg("descriptor"), py::arg("input"), py::arg("output"),
8292
py::arg("is_forward"), py::arg("depends") = py::list());
93+
94+
m.def("compute_fft_in_place", &fft_ns::compute_fft_in_place<double_prec>,
95+
"Compute in-place fft using OneMKL DFT library for complex128 data "
96+
"types.",
97+
py::arg("descriptor"), py::arg("input-output"), py::arg("is_forward"),
98+
py::arg("depends") = py::list());
8399
}

dpnp/dpnp_iface.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
- The functions parameters check
3737
3838
"""
39-
39+
# pylint: disable=protected-access
4040

4141
import os
4242

@@ -53,6 +53,7 @@
5353
from dpnp.random import *
5454

5555
__all__ = [
56+
"are_same_logical_tensors",
5657
"array_equal",
5758
"asnumpy",
5859
"astype",
@@ -127,6 +128,67 @@
127128
__all__ += __all__trigonometric
128129

129130

131+
def are_same_logical_tensors(ar1, ar2):
132+
"""
133+
Check if two arrays are logical views into the same memory.
134+
135+
Parameters
136+
----------
137+
ar1 : {dpnp_array, usm_ndarray}
138+
First input array.
139+
ar2 : {dpnp_array, usm_ndarray}
140+
Second input array.
141+
142+
Returns
143+
-------
144+
out : bool
145+
``True`` if two arrays are logical views into the same memory,
146+
``False`` otherwise.
147+
148+
Examples
149+
--------
150+
>>> import dpnp as np
151+
>>> a = np.array([1, 2, 3])
152+
>>> b = a[:]
153+
>>> a is b
154+
False
155+
>>> np.are_same_logical_tensors(a, b)
156+
True
157+
>>> b[0] = 0
158+
>>> a
159+
array([0, 2, 3])
160+
161+
>>> c = a.copy()
162+
>>> np.are_same_logical_tensors(a, c)
163+
False
164+
165+
"""
166+
check_supported_arrays_type(ar1, ar2)
167+
# Same ndim
168+
nd1 = ar1.ndim
169+
if nd1 != ar2.ndim:
170+
return False
171+
172+
# Same dtype
173+
if ar1.dtype != ar2.dtype:
174+
return False
175+
176+
# Same pointer
177+
if ar1.get_array()._pointer != ar2.get_array()._pointer:
178+
return False
179+
180+
# Same shape
181+
if ar1.shape != ar2.shape:
182+
return False
183+
184+
# Same strides
185+
if ar1.strides != ar2.strides:
186+
return False
187+
188+
# All checks passed: arrays are logical views into the same memory
189+
return True
190+
191+
130192
def array_equal(a1, a2, equal_nan=False):
131193
"""
132194
True if two arrays have the same shape and elements, False otherwise.
@@ -585,7 +647,7 @@ def get_result_array(a, out=None, casting="safe"):
585647
if out is None:
586648
return a
587649

588-
if a is out:
650+
if a is out or dpnp.are_same_logical_tensors(a, out):
589651
return out
590652

591653
dpnp.check_supported_arrays_type(out)

dpnp/fft/dpnp_iface_fft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def fft(a, n=None, axis=-1, norm=None, out=None):
165165
"""
166166

167167
dpnp.check_supported_arrays_type(a)
168-
return dpnp_fft(a, is_forward=True, n=n, axis=axis, norm=norm, out=out)
168+
return dpnp_fft(a, forward=True, n=n, axis=axis, norm=norm, out=out)
169169

170170

171171
def fft2(x, s=None, axes=(-2, -1), norm=None):
@@ -425,7 +425,7 @@ def ifft(a, n=None, axis=-1, norm=None, out=None):
425425
"""
426426

427427
dpnp.check_supported_arrays_type(a)
428-
return dpnp_fft(a, is_forward=False, n=n, axis=axis, norm=norm, out=out)
428+
return dpnp_fft(a, forward=False, n=n, axis=axis, norm=norm, out=out)
429429

430430

431431
def ifft2(x, s=None, axes=(-2, -1), norm=None):

0 commit comments

Comments
 (0)