26
26
#include < pybind11/pybind11.h>
27
27
28
28
// dpctl tensor headers
29
- #include " utils/memory_overlap.hpp"
30
29
#include " utils/type_utils.hpp"
31
30
31
+ #include " common_helpers.hpp"
32
32
#include " gesvd.hpp"
33
+ #include " gesvd_common_utils.hpp"
33
34
#include " types_matrix.hpp"
34
35
35
- #include " dpnp_utils.hpp"
36
-
37
36
namespace dpnp ::extensions::lapack
38
37
{
39
38
namespace mkl_lapack = oneapi::mkl::lapack;
40
39
namespace py = pybind11;
41
40
namespace type_utils = dpctl::tensor::type_utils;
42
41
43
- typedef sycl::event (*gesvd_impl_fn_ptr_t )(sycl::queue,
42
+ typedef sycl::event (*gesvd_impl_fn_ptr_t )(sycl::queue & ,
44
43
const oneapi::mkl::jobsvd,
45
44
const oneapi::mkl::jobsvd,
46
45
const std::int64_t ,
@@ -58,26 +57,8 @@ typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue,
58
57
static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types]
59
58
[dpctl_td_ns::num_types];
60
59
61
- // Converts a given character code (ord) to the corresponding
62
- // oneapi::mkl::jobsvd enumeration value
63
- static oneapi::mkl::jobsvd process_job (std::int8_t job_val)
64
- {
65
- switch (job_val) {
66
- case ' A' :
67
- return oneapi::mkl::jobsvd::vectors;
68
- case ' S' :
69
- return oneapi::mkl::jobsvd::somevec;
70
- case ' O' :
71
- return oneapi::mkl::jobsvd::vectorsina;
72
- case ' N' :
73
- return oneapi::mkl::jobsvd::novec;
74
- default :
75
- throw std::invalid_argument (" Unknown value for job" );
76
- }
77
- }
78
-
79
60
template <typename T, typename RealT>
80
- static sycl::event gesvd_impl (sycl::queue exec_q,
61
+ static sycl::event gesvd_impl (sycl::queue & exec_q,
81
62
const oneapi::mkl::jobsvd jobu,
82
63
const oneapi::mkl::jobsvd jobvt,
83
64
const std::int64_t m,
@@ -102,16 +83,14 @@ static sycl::event gesvd_impl(sycl::queue exec_q,
102
83
103
84
const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>(
104
85
exec_q, jobu, jobvt, m, n, lda, ldu, ldvt);
105
- T *scratchpad = nullptr ;
86
+
87
+ T *scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);
106
88
107
89
std::stringstream error_msg;
108
- std::int64_t info = 0 ;
109
90
bool is_exception_caught = false ;
110
91
111
92
sycl::event gesvd_event;
112
93
try {
113
- scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
114
-
115
94
gesvd_event = mkl_lapack::gesvd (
116
95
exec_q,
117
96
jobu, // Character specifying how to compute the matrix U:
@@ -138,26 +117,7 @@ static sycl::event gesvd_impl(sycl::queue exec_q,
138
117
scratchpad_size, depends);
139
118
} catch (mkl_lapack::exception const &e) {
140
119
is_exception_caught = true ;
141
- info = e.info ();
142
- if (info < 0 ) {
143
- error_msg << " Parameter number " << -info
144
- << " had an illegal value." ;
145
- }
146
- else if (info == scratchpad_size && e.detail () != 0 ) {
147
- error_msg
148
- << " Insufficient scratchpad size. Required size is at least "
149
- << e.detail ();
150
- }
151
- else if (info > 0 ) {
152
- error_msg << " The algorithm computing SVD failed to converge; "
153
- << info << " off-diagonal elements of an intermediate "
154
- << " bidiagonal form did not converge to zero.\n " ;
155
- }
156
- else {
157
- error_msg << " Unexpected MKL exception caught during gesvd() "
158
- " call:\n reason: "
159
- << e.what () << " \n info: " << e.info ();
160
- }
120
+ gesvd_utils::handle_lapack_exc (scratchpad_size, e, error_msg);
161
121
} catch (sycl::exception const &e) {
162
122
is_exception_caught = true ;
163
123
error_msg << " Unexpected SYCL exception caught during gesvd() call:\n "
@@ -182,7 +142,7 @@ static sycl::event gesvd_impl(sycl::queue exec_q,
182
142
}
183
143
184
144
std::pair<sycl::event, sycl::event>
185
- gesvd (sycl::queue exec_q,
145
+ gesvd (sycl::queue & exec_q,
186
146
const std::int8_t jobu_val,
187
147
const std::int8_t jobvt_val,
188
148
dpctl::tensor::usm_ndarray a_array,
@@ -191,103 +151,26 @@ std::pair<sycl::event, sycl::event>
191
151
dpctl::tensor::usm_ndarray out_vt,
192
152
const std::vector<sycl::event> &depends)
193
153
{
194
- const int a_array_nd = a_array.get_ndim ();
195
- const int out_u_array_nd = out_u.get_ndim ();
196
- const int out_s_array_nd = out_s.get_ndim ();
197
- const int out_vt_array_nd = out_vt.get_ndim ();
198
-
199
- if (a_array_nd != 2 ) {
200
- throw py::value_error (
201
- " The input array has ndim=" + std::to_string (a_array_nd) +
202
- " , but a 2-dimensional array is expected." );
203
- }
204
-
205
- if (out_s_array_nd != 1 ) {
206
- throw py::value_error (" The output array of singular values has ndim=" +
207
- std::to_string (out_s_array_nd) +
208
- " , but a 1-dimensional array is expected." );
209
- }
210
-
211
- if (jobu_val == ' N' && jobvt_val == ' N' ) {
212
- if (out_u_array_nd != 0 ) {
213
- throw py::value_error (
214
- " The output array of the left singular vectors has ndim=" +
215
- std::to_string (out_u_array_nd) +
216
- " , but it is not used and should have ndim=0." );
217
- }
218
- if (out_vt_array_nd != 0 ) {
219
- throw py::value_error (
220
- " The output array of the right singular vectors has ndim=" +
221
- std::to_string (out_vt_array_nd) +
222
- " , but it is not used and should have ndim=0." );
223
- }
224
- }
225
- else {
226
- if (out_u_array_nd != 2 ) {
227
- throw py::value_error (
228
- " The output array of the left singular vectors has ndim=" +
229
- std::to_string (out_u_array_nd) +
230
- " , but a 2-dimensional array is expected." );
231
- }
232
- if (out_vt_array_nd != 2 ) {
233
- throw py::value_error (
234
- " The output array of the right singular vectors has ndim=" +
235
- std::to_string (out_vt_array_nd) +
236
- " , but a 2-dimensional array is expected." );
237
- }
238
- }
239
-
240
- // check compatibility of execution queue and allocation queue
241
- if (!dpctl::utils::queues_are_compatible (
242
- exec_q, {a_array.get_queue (), out_s.get_queue (), out_u.get_queue (),
243
- out_vt.get_queue ()}))
244
- {
245
- throw std::runtime_error (
246
- " USM allocations are not compatible with the execution queue." );
247
- }
248
-
249
- auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
250
- if (overlap (a_array, out_s) || overlap (a_array, out_u) ||
251
- overlap (a_array, out_vt) || overlap (out_s, out_u) ||
252
- overlap (out_s, out_vt) || overlap (out_u, out_vt))
253
- {
254
- throw py::value_error (" Arrays have overlapping segments of memory" );
255
- }
256
-
257
- bool is_a_array_f_contig = a_array.is_f_contiguous ();
258
- if (!is_a_array_f_contig) {
259
- throw py::value_error (" The input array must be F-contiguous" );
260
- }
261
-
262
- bool is_out_u_array_f_contig = out_u.is_f_contiguous ();
263
- bool is_out_vt_array_f_contig = out_vt.is_f_contiguous ();
264
-
265
- if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
266
- throw py::value_error (" The output arrays of the left and right "
267
- " singular vectors must be F-contiguous" );
268
- }
269
-
270
- bool is_out_s_array_c_contig = out_s.is_c_contiguous ();
271
- bool is_out_s_array_f_contig = out_s.is_f_contiguous ();
272
-
273
- if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) {
274
- throw py::value_error (" The output array of singular values "
275
- " must be contiguous" );
276
- }
154
+ constexpr int expected_a_u_vt_ndim = 2 ;
155
+ constexpr int expected_s_ndim = 1 ;
156
+
157
+ gesvd_utils::common_gesvd_checks (exec_q, a_array, out_s, out_u, out_vt,
158
+ jobu_val, jobvt_val, expected_a_u_vt_ndim,
159
+ expected_s_ndim);
160
+
161
+ // // Ensure `m` and 'n' are non-zero, otherwise return empty
162
+ // // events
163
+ // if (gesvd_utils::check_zeros_shape_gesvd(a_array, out_s, out_u, out_vt,
164
+ // jobu_val, jobvt_val))
165
+ // {
166
+ // // nothing to do
167
+ // return std::make_pair(sycl::event(), sycl::event());
168
+ // }
277
169
278
170
auto array_types = dpctl_td_ns::usm_ndarray_types ();
279
171
int a_array_type_id =
280
172
array_types.typenum_to_lookup_id (a_array.get_typenum ());
281
- int out_u_type_id = array_types.typenum_to_lookup_id (out_u.get_typenum ());
282
173
int out_s_type_id = array_types.typenum_to_lookup_id (out_s.get_typenum ());
283
- int out_vt_type_id = array_types.typenum_to_lookup_id (out_vt.get_typenum ());
284
-
285
- if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
286
- throw py::type_error (
287
- " Input array, output left singular vectors array, "
288
- " and outpuy right singular vectors array must have "
289
- " the same data type" );
290
- }
291
174
292
175
gesvd_impl_fn_ptr_t gesvd_fn =
293
176
gesvd_dispatch_table[a_array_type_id][out_s_type_id];
@@ -311,8 +194,8 @@ std::pair<sycl::event, sycl::event>
311
194
const std::int64_t ldvt =
312
195
std::max<std::size_t >(1UL , jobvt_val == ' S' ? (m > n ? n : m) : n);
313
196
314
- const oneapi::mkl::jobsvd jobu = process_job (jobu_val);
315
- const oneapi::mkl::jobsvd jobvt = process_job (jobvt_val);
197
+ const oneapi::mkl::jobsvd jobu = gesvd_utils:: process_job (jobu_val);
198
+ const oneapi::mkl::jobsvd jobvt = gesvd_utils:: process_job (jobvt_val);
316
199
317
200
std::vector<sycl::event> host_task_events;
318
201
sycl::event gesvd_ev =
0 commit comments