Skip to content

Commit 9fd03f1

Browse files
Implement common_gesvd_checks
1 parent bbb46bd commit 9fd03f1

File tree

3 files changed

+135
-185
lines changed

3 files changed

+135
-185
lines changed

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 6 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,13 @@
2626
#include <pybind11/pybind11.h>
2727

2828
// dpctl tensor headers
29-
#include "utils/memory_overlap.hpp"
3029
#include "utils/type_utils.hpp"
3130

3231
#include "common_helpers.hpp"
3332
#include "gesvd.hpp"
3433
#include "gesvd_common_utils.hpp"
3534
#include "types_matrix.hpp"
3635

37-
#include "dpnp_utils.hpp"
38-
3936
namespace dpnp::extensions::lapack
4037
{
4138
namespace mkl_lapack = oneapi::mkl::lapack;
@@ -154,103 +151,19 @@ std::pair<sycl::event, sycl::event>
154151
dpctl::tensor::usm_ndarray out_vt,
155152
const std::vector<sycl::event> &depends)
156153
{
157-
const int a_array_nd = a_array.get_ndim();
158-
const int out_u_array_nd = out_u.get_ndim();
159-
const int out_s_array_nd = out_s.get_ndim();
160-
const int out_vt_array_nd = out_vt.get_ndim();
161-
162-
if (a_array_nd != 2) {
163-
throw py::value_error(
164-
"The input array has ndim=" + std::to_string(a_array_nd) +
165-
", but a 2-dimensional array is expected.");
166-
}
154+
constexpr int expected_a_u_vt_ndim = 2;
155+
constexpr int expected_s_ndim = 1;
167156

168-
if (out_s_array_nd != 1) {
169-
throw py::value_error("The output array of singular values has ndim=" +
170-
std::to_string(out_s_array_nd) +
171-
", but a 1-dimensional array is expected.");
172-
}
173-
174-
if (jobu_val == 'N' && jobvt_val == 'N') {
175-
if (out_u_array_nd != 0) {
176-
throw py::value_error(
177-
"The output array of the left singular vectors has ndim=" +
178-
std::to_string(out_u_array_nd) +
179-
", but it is not used and should have ndim=0.");
180-
}
181-
if (out_vt_array_nd != 0) {
182-
throw py::value_error(
183-
"The output array of the right singular vectors has ndim=" +
184-
std::to_string(out_vt_array_nd) +
185-
", but it is not used and should have ndim=0.");
186-
}
187-
}
188-
else {
189-
if (out_u_array_nd != 2) {
190-
throw py::value_error(
191-
"The output array of the left singular vectors has ndim=" +
192-
std::to_string(out_u_array_nd) +
193-
", but a 2-dimensional array is expected.");
194-
}
195-
if (out_vt_array_nd != 2) {
196-
throw py::value_error(
197-
"The output array of the right singular vectors has ndim=" +
198-
std::to_string(out_vt_array_nd) +
199-
", but a 2-dimensional array is expected.");
200-
}
201-
}
157+
gesvd_utils::common_gesvd_checks(exec_q, a_array, out_s, out_u, out_vt,
158+
jobu_val, jobvt_val, expected_a_u_vt_ndim,
159+
expected_s_ndim);
202160

203-
// check compatibility of execution queue and allocation queue
204-
if (!dpctl::utils::queues_are_compatible(
205-
exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(),
206-
out_vt.get_queue()}))
207-
{
208-
throw std::runtime_error(
209-
"USM allocations are not compatible with the execution queue.");
210-
}
211-
212-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
213-
if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
214-
overlap(a_array, out_vt) || overlap(out_s, out_u) ||
215-
overlap(out_s, out_vt) || overlap(out_u, out_vt))
216-
{
217-
throw py::value_error("Arrays have overlapping segments of memory");
218-
}
219-
220-
bool is_a_array_f_contig = a_array.is_f_contiguous();
221-
if (!is_a_array_f_contig) {
222-
throw py::value_error("The input array must be F-contiguous");
223-
}
224-
225-
bool is_out_u_array_f_contig = out_u.is_f_contiguous();
226-
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
227-
228-
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
229-
throw py::value_error("The output arrays of the left and right "
230-
"singular vectors must be F-contiguous");
231-
}
232-
233-
bool is_out_s_array_c_contig = out_s.is_c_contiguous();
234-
bool is_out_s_array_f_contig = out_s.is_f_contiguous();
235-
236-
if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) {
237-
throw py::value_error("The output array of singular values "
238-
"must be contiguous");
239-
}
161+
// TODO: check non_zero shape
240162

241163
auto array_types = dpctl_td_ns::usm_ndarray_types();
242164
int a_array_type_id =
243165
array_types.typenum_to_lookup_id(a_array.get_typenum());
244-
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
245166
int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum());
246-
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());
247-
248-
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
249-
throw py::type_error(
250-
"Input array, output left singular vectors array, "
251-
"and outpuy right singular vectors array must have "
252-
"the same data type");
253-
}
254167

255168
gesvd_impl_fn_ptr_t gesvd_fn =
256169
gesvd_dispatch_table[a_array_type_id][out_s_type_id];

dpnp/backend/extensions/lapack/gesvd_batch.cpp

Lines changed: 6 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,13 @@
2626
#include <pybind11/pybind11.h>
2727

2828
// dpctl tensor headers
29-
#include "utils/memory_overlap.hpp"
3029
#include "utils/type_utils.hpp"
3130

3231
#include "common_helpers.hpp"
3332
#include "gesvd.hpp"
3433
#include "gesvd_common_utils.hpp"
3534
#include "types_matrix.hpp"
3635

37-
#include "dpnp_utils.hpp"
38-
3936
namespace dpnp::extensions::lapack
4037
{
4138
namespace mkl_lapack = oneapi::mkl::lapack;
@@ -214,102 +211,19 @@ std::pair<sycl::event, sycl::event>
214211
dpctl::tensor::usm_ndarray out_vt,
215212
const std::vector<sycl::event> &depends)
216213
{
217-
const int a_array_nd = a_array.get_ndim();
218-
const int out_u_array_nd = out_u.get_ndim();
219-
const int out_s_array_nd = out_s.get_ndim();
220-
const int out_vt_array_nd = out_vt.get_ndim();
221-
222-
if (a_array_nd != 3) {
223-
throw py::value_error(
224-
"The input array has ndim=" + std::to_string(a_array_nd) +
225-
", but a 3-dimensional array is expected.");
226-
}
214+
constexpr int expected_a_u_vt_ndim = 3;
215+
constexpr int expected_s_ndim = 2;
227216

228-
if (out_s_array_nd != 2) {
229-
throw py::value_error("The output array of singular values has ndim=" +
230-
std::to_string(out_s_array_nd) +
231-
", but a 2-dimensional array is expected.");
232-
}
233-
234-
if (jobu_val == 'N' && jobvt_val == 'N') {
235-
if (out_u_array_nd != 0) {
236-
throw py::value_error(
237-
"The output array of the left singular vectors has ndim=" +
238-
std::to_string(out_u_array_nd) +
239-
", but it is not used and should have ndim=0.");
240-
}
241-
if (out_vt_array_nd != 0) {
242-
throw py::value_error(
243-
"The output array of the right singular vectors has ndim=" +
244-
std::to_string(out_vt_array_nd) +
245-
", but it is not used and should have ndim=0.");
246-
}
247-
}
248-
else {
249-
if (out_u_array_nd != 3) {
250-
throw py::value_error(
251-
"The output array of the left singular vectors has ndim=" +
252-
std::to_string(out_u_array_nd) +
253-
", but a 3-dimensional array is expected.");
254-
}
255-
if (out_vt_array_nd != 3) {
256-
throw py::value_error(
257-
"The output array of the right singular vectors has ndim=" +
258-
std::to_string(out_vt_array_nd) +
259-
", but a 3-dimensional array is expected.");
260-
}
261-
}
217+
gesvd_utils::common_gesvd_checks(exec_q, a_array, out_s, out_u, out_vt,
218+
jobu_val, jobvt_val, expected_a_u_vt_ndim,
219+
expected_s_ndim);
262220

263-
// check compatibility of execution queue and allocation queue
264-
if (!dpctl::utils::queues_are_compatible(
265-
exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(),
266-
out_vt.get_queue()}))
267-
{
268-
throw std::runtime_error(
269-
"USM allocations are not compatible with the execution queue.");
270-
}
271-
272-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
273-
if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
274-
overlap(a_array, out_vt) || overlap(out_s, out_u) ||
275-
overlap(out_s, out_vt) || overlap(out_u, out_vt))
276-
{
277-
throw py::value_error("Arrays have overlapping segments of memory");
278-
}
279-
280-
bool is_a_array_f_contig = a_array.is_f_contiguous();
281-
if (!is_a_array_f_contig) {
282-
throw py::value_error("The input array must be F-contiguous");
283-
}
284-
285-
bool is_out_u_array_f_contig = out_u.is_f_contiguous();
286-
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
287-
288-
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
289-
throw py::value_error("The output arrays of the left and right "
290-
"singular vectors must be F-contiguous");
291-
}
292-
293-
bool is_out_s_array_c_contig = out_s.is_c_contiguous();
294-
295-
if (!is_out_s_array_c_contig) {
296-
throw py::value_error("The output array of singular values "
297-
"must be C-contiguous");
298-
}
221+
// TODO: check non_zero shape
299222

300223
auto array_types = dpctl_td_ns::usm_ndarray_types();
301224
int a_array_type_id =
302225
array_types.typenum_to_lookup_id(a_array.get_typenum());
303-
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
304226
int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum());
305-
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());
306-
307-
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
308-
throw py::type_error(
309-
"Input array, output left singular vectors array, "
310-
"and outpuy right singular vectors array must have "
311-
"the same data type");
312-
}
313227

314228
gesvd_batch_impl_fn_ptr_t gesvd_batch_fn =
315229
gesvd_batch_dispatch_table[a_array_type_id][out_s_type_id];

dpnp/backend/extensions/lapack/gesvd_common_utils.hpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,17 @@
2525

2626
#pragma once
2727
#include <oneapi/mkl.hpp>
28+
#include <pybind11/pybind11.h>
29+
30+
// dpctl tensor headers
31+
#include "utils/memory_overlap.hpp"
32+
#include "utils/output_validation.hpp"
33+
#include "utils/type_dispatch.hpp"
2834

2935
namespace dpnp::extensions::lapack::gesvd_utils
3036
{
37+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
38+
namespace py = pybind11;
3139

3240
// Converts a given character code (ord) to the corresponding
3341
// oneapi::mkl::jobsvd enumeration value
@@ -47,6 +55,121 @@ inline oneapi::mkl::jobsvd process_job(std::int8_t job_val)
4755
}
4856
}
4957

58+
inline void common_gesvd_checks(sycl::queue &exec_q,
59+
dpctl::tensor::usm_ndarray a_array,
60+
dpctl::tensor::usm_ndarray out_s,
61+
dpctl::tensor::usm_ndarray out_u,
62+
dpctl::tensor::usm_ndarray out_vt,
63+
const std::int8_t jobu_val,
64+
const std::int8_t jobvt_val,
65+
const int expected_a_u_vt_ndim,
66+
const int expected_s_ndim)
67+
{
68+
const int a_array_nd = a_array.get_ndim();
69+
const int out_u_array_nd = out_u.get_ndim();
70+
const int out_s_array_nd = out_s.get_ndim();
71+
const int out_vt_array_nd = out_vt.get_ndim();
72+
73+
if (a_array_nd != expected_a_u_vt_ndim) {
74+
throw py::value_error(
75+
"The input array has ndim=" + std::to_string(a_array_nd) +
76+
", but a " + std::to_string(expected_a_u_vt_ndim) +
77+
"-dimensional array is expected.");
78+
}
79+
80+
if (out_s_array_nd != expected_s_ndim) {
81+
throw py::value_error("The output array of singular values has ndim=" +
82+
std::to_string(out_s_array_nd) + ", but a " +
83+
std::to_string(expected_s_ndim) +
84+
"-dimensional array is expected.");
85+
}
86+
87+
if (jobu_val == 'N' && jobvt_val == 'N') {
88+
if (out_u_array_nd != 0) {
89+
throw py::value_error(
90+
"The output array of the left singular vectors has ndim=" +
91+
std::to_string(out_u_array_nd) +
92+
", but it is not used and should have ndim=0.");
93+
}
94+
if (out_vt_array_nd != 0) {
95+
throw py::value_error(
96+
"The output array of the right singular vectors has ndim=" +
97+
std::to_string(out_vt_array_nd) +
98+
", but it is not used and should have ndim=0.");
99+
}
100+
}
101+
else {
102+
if (out_u_array_nd != expected_a_u_vt_ndim) {
103+
throw py::value_error(
104+
"The output array of the left singular vectors has ndim=" +
105+
std::to_string(out_u_array_nd) + ", but a " +
106+
std::to_string(expected_a_u_vt_ndim) +
107+
"-dimensional array is expected.");
108+
}
109+
if (out_vt_array_nd != expected_a_u_vt_ndim) {
110+
throw py::value_error(
111+
"The output array of the right singular vectors has ndim=" +
112+
std::to_string(out_vt_array_nd) + ", but a " +
113+
std::to_string(expected_a_u_vt_ndim) +
114+
"-dimensional array is expected.");
115+
}
116+
}
117+
118+
// check compatibility of execution queue and allocation queue
119+
if (!dpctl::utils::queues_are_compatible(exec_q,
120+
{a_array, out_s, out_u, out_vt}))
121+
{
122+
throw py::value_error(
123+
"Execution queue is not compatible with allocation queues.");
124+
}
125+
126+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
127+
if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
128+
overlap(a_array, out_vt) || overlap(out_s, out_u) ||
129+
overlap(out_s, out_vt) || overlap(out_u, out_vt))
130+
{
131+
throw py::value_error("Arrays have overlapping segments of memory");
132+
}
133+
134+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(a_array);
135+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out_s);
136+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out_u);
137+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out_vt);
138+
139+
bool is_a_array_f_contig = a_array.is_f_contiguous();
140+
if (!is_a_array_f_contig) {
141+
throw py::value_error("The input array must be F-contiguous");
142+
}
143+
144+
bool is_out_u_array_f_contig = out_u.is_f_contiguous();
145+
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
146+
147+
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
148+
throw py::value_error("The output arrays of the left and right "
149+
"singular vectors must be F-contiguous");
150+
}
151+
152+
bool is_out_s_array_c_contig = out_s.is_c_contiguous();
153+
154+
if (!is_out_s_array_c_contig) {
155+
throw py::value_error("The output array of singular values "
156+
"must be C-contiguous");
157+
}
158+
159+
auto array_types = dpctl_td_ns::usm_ndarray_types();
160+
int a_array_type_id =
161+
array_types.typenum_to_lookup_id(a_array.get_typenum());
162+
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
163+
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());
164+
165+
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
166+
throw py::type_error(
167+
"Input array, output left singular vectors array, "
168+
"and outpuy right singular vectors array must have "
169+
"the same data type");
170+
}
171+
}
172+
50173
inline void handle_lapack_exc(std::int64_t scratchpad_size,
51174
const oneapi::mkl::lapack::exception &e,
52175
std::stringstream &error_msg)

0 commit comments

Comments
 (0)