Skip to content

Commit 2ed3501

Browse files
committed
update backend structure
1 parent 12ad420 commit 2ed3501

File tree

4 files changed

+201
-104
lines changed

4 files changed

+201
-104
lines changed

dpnp/backend/extensions/fft/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set(python_module_name _fft_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/fft_py.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/c2c.cpp
3031
)
3132

3233
pybind11_add_module(${python_module_name} MODULE ${_module_src})

dpnp/backend/extensions/fft/c2c.cpp

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <oneapi/mkl.hpp>
27+
#include <sycl/sycl.hpp>
28+
29+
#include <dpctl4pybind11.hpp>
30+
31+
#include "c2c.hpp"
32+
#include "fft_utils.hpp"
33+
// dpctl tensor headers
34+
#include "utils/memory_overlap.hpp"
35+
#include "utils/output_validation.hpp"
36+
37+
namespace dpnp::extensions::fft
38+
{
39+
namespace mkl_dft = oneapi::mkl::dft;
40+
namespace py = pybind11;
41+
42+
template <mkl_dft::precision prec>
43+
std::pair<sycl::event, sycl::event>
44+
compute_fft(ComplexDescriptorWrapper<prec> &descr,
45+
const dpctl::tensor::usm_ndarray &in,
46+
const dpctl::tensor::usm_ndarray &out,
47+
const bool is_forward,
48+
const std::vector<sycl::event> &depends)
49+
{
50+
// TODO: activate in MKL=2024.2
51+
// bool committed = descr.is_committed();
52+
// if (!committed) {
53+
// throw py::value_error("Descriptor is not committed");
54+
//}
55+
56+
const bool in_place = descr.get_in_place();
57+
if (in_place) {
58+
throw py::value_error(
59+
"Descriptor is defined for in-place FFT while this function is set "
60+
"to compute out-of-place FFT.");
61+
}
62+
63+
const int in_nd = in.get_ndim();
64+
const int out_nd = out.get_ndim();
65+
if ((in_nd != out_nd)) {
66+
throw py::value_error(
67+
"The input and output arrays must have the same dimension.");
68+
}
69+
70+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
71+
if (overlap(in, out)) {
72+
throw py::value_error("The input and output arrays are overlapping "
73+
"segments of memory");
74+
}
75+
76+
sycl::queue exec_q = descr.get_queue();
77+
if (!dpctl::utils::queues_are_compatible(exec_q,
78+
{in.get_queue(), out.get_queue()}))
79+
{
80+
throw py::value_error(
81+
"USM allocations are not compatible with the execution queue.");
82+
}
83+
84+
py::ssize_t in_size = in.get_size();
85+
py::ssize_t out_size = out.get_size();
86+
if (in_size != out_size) {
87+
throw py::value_error("The size of the input vector must be "
88+
"equal to the size of the output vector.");
89+
}
90+
91+
size_t src_nelems = in_size;
92+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out);
93+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(out, src_nelems);
94+
95+
using ScaleT = typename ScaleType<prec>::value_type;
96+
std::complex<ScaleT> *in_ptr = in.get_data<std::complex<ScaleT>>();
97+
std::complex<ScaleT> *out_ptr = out.get_data<std::complex<ScaleT>>();
98+
99+
sycl::event fft_event = {};
100+
std::stringstream error_msg;
101+
bool is_exception_caught = false;
102+
103+
try {
104+
if (is_forward) {
105+
fft_event = oneapi::mkl::dft::compute_forward(
106+
descr.get_descriptor(), in_ptr, out_ptr, depends);
107+
}
108+
else {
109+
fft_event = oneapi::mkl::dft::compute_backward(
110+
descr.get_descriptor(), in_ptr, out_ptr, depends);
111+
}
112+
} catch (oneapi::mkl::exception const &e) {
113+
error_msg
114+
<< "Unexpected MKL exception caught during FFT() call:\nreason: "
115+
<< e.what();
116+
is_exception_caught = true;
117+
} catch (sycl::exception const &e) {
118+
error_msg << "Unexpected SYCL exception caught during FFT() call:\n"
119+
<< e.what();
120+
is_exception_caught = true;
121+
}
122+
if (is_exception_caught) {
123+
throw std::runtime_error(error_msg.str());
124+
}
125+
126+
sycl::event args_ev =
127+
dpctl::utils::keep_args_alive(exec_q, {in, out}, {fft_event});
128+
129+
return std::make_pair(fft_event, args_ev);
130+
}
131+
132+
// Explicit instantiations
133+
template std::pair<sycl::event, sycl::event>
134+
compute_fft(ComplexDescriptorWrapper<mkl_dft::precision::SINGLE> &descr,
135+
const dpctl::tensor::usm_ndarray &in,
136+
const dpctl::tensor::usm_ndarray &out,
137+
const bool is_forward,
138+
const std::vector<sycl::event> &depends);
139+
140+
template std::pair<sycl::event, sycl::event>
141+
compute_fft(ComplexDescriptorWrapper<mkl_dft::precision::DOUBLE> &descr,
142+
const dpctl::tensor::usm_ndarray &in,
143+
const dpctl::tensor::usm_ndarray &out,
144+
const bool is_forward,
145+
const std::vector<sycl::event> &depends);
146+
147+
} // namespace dpnp::extensions::fft

dpnp/backend/extensions/fft/c2c.hpp

Lines changed: 1 addition & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -30,33 +30,11 @@
3030

3131
#include <dpctl4pybind11.hpp>
3232

33-
// dpctl tensor headers
34-
#include "utils/memory_overlap.hpp"
35-
#include "utils/output_validation.hpp"
36-
3733
namespace dpnp::extensions::fft
3834
{
3935
namespace mkl_dft = oneapi::mkl::dft;
4036
namespace py = pybind11;
4137

42-
template <mkl_dft::precision prec>
43-
struct ScaleType
44-
{
45-
using value_type = void;
46-
};
47-
48-
template <>
49-
struct ScaleType<mkl_dft::precision::SINGLE>
50-
{
51-
using value_type = float;
52-
};
53-
54-
template <>
55-
struct ScaleType<mkl_dft::precision::DOUBLE>
56-
{
57-
using value_type = double;
58-
};
59-
6038
template <mkl_dft::precision prec>
6139
class ComplexDescriptorWrapper
6240
{
@@ -254,87 +232,6 @@ std::pair<sycl::event, sycl::event>
254232
const dpctl::tensor::usm_ndarray &in,
255233
const dpctl::tensor::usm_ndarray &out,
256234
const bool is_forward,
257-
const std::vector<sycl::event> &depends)
258-
{
259-
// TODO: activate in MKL=2024.2
260-
// bool committed = descr.is_committed();
261-
// if (!committed) {
262-
// throw py::value_error("Descriptor is not committed");
263-
//}
264-
265-
const bool in_place = descr.get_in_place();
266-
if (in_place) {
267-
throw py::value_error(
268-
"Descriptor is defined for in-place FFT while this function is set "
269-
"to compute out-of-place FFT.");
270-
}
271-
272-
const int in_nd = in.get_ndim();
273-
const int out_nd = out.get_ndim();
274-
if ((in_nd != out_nd)) {
275-
throw py::value_error(
276-
"The input and output arrays must have the same dimension.");
277-
}
278-
279-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
280-
if (overlap(in, out)) {
281-
throw py::value_error("The input and output arrays are overlapping "
282-
"segments of memory");
283-
}
284-
285-
sycl::queue exec_q = descr.get_queue();
286-
if (!dpctl::utils::queues_are_compatible(exec_q,
287-
{in.get_queue(), out.get_queue()}))
288-
{
289-
throw py::value_error(
290-
"USM allocations are not compatible with the execution queue.");
291-
}
292-
293-
py::ssize_t in_size = in.get_size();
294-
py::ssize_t out_size = out.get_size();
295-
if (in_size != out_size) {
296-
throw py::value_error("The size of the input vector must be "
297-
"equal to the size of the output vector.");
298-
}
299-
300-
size_t src_nelems = in_size;
301-
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out);
302-
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(out, src_nelems);
303-
304-
using ScaleT = typename ScaleType<prec>::value_type;
305-
std::complex<ScaleT> *in_ptr = in.get_data<std::complex<ScaleT>>();
306-
std::complex<ScaleT> *out_ptr = out.get_data<std::complex<ScaleT>>();
307-
308-
sycl::event fft_event = {};
309-
std::stringstream error_msg;
310-
bool is_exception_caught = false;
311-
312-
try {
313-
if (is_forward) {
314-
fft_event = oneapi::mkl::dft::compute_forward(
315-
descr.get_descriptor(), in_ptr, out_ptr, depends);
316-
}
317-
else {
318-
fft_event = oneapi::mkl::dft::compute_backward(
319-
descr.get_descriptor(), in_ptr, out_ptr, depends);
320-
}
321-
} catch (oneapi::mkl::exception const &e) {
322-
error_msg
323-
<< "Unexpected MKL exception caught during FFT() call:\nreason: "
324-
<< e.what();
325-
is_exception_caught = true;
326-
} catch (sycl::exception const &e) {
327-
error_msg << "Unexpected SYCL exception caught during FFT() call:\n"
328-
<< e.what();
329-
is_exception_caught = true;
330-
}
331-
if (is_exception_caught) {
332-
throw std::runtime_error(error_msg.str());
333-
}
334-
335-
sycl::event args_ev =
336-
dpctl::utils::keep_args_alive(exec_q, {in, out}, {fft_event});
235+
const std::vector<sycl::event> &depends);
337236

338-
return std::make_pair(fft_event, args_ev);
339-
}
340237
} // namespace dpnp::extensions::fft
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl.hpp>
29+
30+
namespace dpnp::extensions::fft
31+
{
32+
namespace mkl_dft = oneapi::mkl::dft;
33+
34+
template <mkl_dft::precision prec>
35+
struct ScaleType
36+
{
37+
using value_type = void;
38+
};
39+
40+
template <>
41+
struct ScaleType<mkl_dft::precision::SINGLE>
42+
{
43+
using value_type = float;
44+
};
45+
46+
template <>
47+
struct ScaleType<mkl_dft::precision::DOUBLE>
48+
{
49+
using value_type = double;
50+
};
51+
52+
} // namespace dpnp::extensions::fft

0 commit comments

Comments
 (0)