Skip to content

Commit 2a2cdfa

Browse files
Fix RuntimeError raising in dpnp.linalg.solve() (#1763)
* Correct parameter calculation for gesv * Use getrf and getrs MKL funcs in dpnp_solve for 2d array * Add test to cover SAT-6701 case * Extend test_solve in test_sycl_queue.py * Address remarks * Support as F-contiguous for _getrs --------- Co-authored-by: Anton <[email protected]>
1 parent 6f056ca commit 2a2cdfa

File tree

9 files changed

+509
-34
lines changed

9 files changed

+509
-34
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ set(_module_src
3434
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
3535
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3636
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
37+
${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp
3738
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
3839
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp

dpnp/backend/extensions/lapack/gesv.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,18 @@ static sycl::event gesv_impl(sycl::queue exec_q,
9393

9494
gesv_event = mkl_lapack::gesv(
9595
exec_q,
96-
n, // The order of the matrix A (0 ≤ n).
97-
nrhs, // The number of right-hand sides B (0 ≤ nrhs).
96+
n, // The order of the square matrix A
97+
// and the number of rows in matrix B (0 ≤ n).
98+
nrhs, // The number of right-hand sides,
99+
// i.e., the number of columns in matrix B (0 ≤ nrhs).
98100
a, // Pointer to the square coefficient matrix A (n x n).
99101
lda, // The leading dimension of a, must be at least max(1, n).
100102
ipiv, // The pivot indices that define the permutation matrix P;
101103
// row i of the matrix was interchanged with row ipiv(i),
102104
// must be at least max(1, n).
103105
b, // Pointer to the right hand side matrix B (n x nrhs).
104-
ldb, // The leading dimension of b, must be at least max(1, n).
106+
ldb, // The leading dimension of matrix B,
107+
// must be at least max(1, n).
105108
scratchpad, // Pointer to scratchpad memory to be used by MKL
106109
// routine for storing intermediate results.
107110
scratchpad_size, depends);
@@ -252,13 +255,12 @@ std::pair<sycl::event, sycl::event>
252255
char *coeff_matrix_data = coeff_matrix.get_data();
253256
char *dependent_vals_data = dependent_vals.get_data();
254257

255-
const std::int64_t n = coeff_matrix_shape[0];
256-
const std::int64_t m = dependent_vals_shape[0];
258+
const std::int64_t n = dependent_vals_shape[0];
257259
const std::int64_t nrhs =
258260
(dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1;
259261

260262
const std::int64_t lda = std::max<size_t>(1UL, n);
261-
const std::int64_t ldb = std::max<size_t>(1UL, m);
263+
const std::int64_t ldb = std::max<size_t>(1UL, n);
262264

263265
std::vector<sycl::event> host_task_events;
264266
sycl::event gesv_ev =
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "getrs.hpp"
33+
#include "linalg_exceptions.hpp"
34+
#include "types_matrix.hpp"
35+
36+
#include "dpnp_utils.hpp"
37+
38+
namespace dpnp
39+
{
40+
namespace backend
41+
{
42+
namespace ext
43+
{
44+
namespace lapack
45+
{
46+
namespace mkl_lapack = oneapi::mkl::lapack;
47+
namespace py = pybind11;
48+
namespace type_utils = dpctl::tensor::type_utils;
49+
50+
typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue,
51+
oneapi::mkl::transpose,
52+
const std::int64_t,
53+
const std::int64_t,
54+
char *,
55+
std::int64_t,
56+
std::int64_t *,
57+
char *,
58+
std::int64_t,
59+
std::vector<sycl::event> &,
60+
const std::vector<sycl::event> &);
61+
62+
static getrs_impl_fn_ptr_t getrs_dispatch_vector[dpctl_td_ns::num_types];
63+
64+
template <typename T>
65+
static sycl::event getrs_impl(sycl::queue exec_q,
66+
oneapi::mkl::transpose trans,
67+
const std::int64_t n,
68+
const std::int64_t nrhs,
69+
char *in_a,
70+
std::int64_t lda,
71+
std::int64_t *ipiv,
72+
char *in_b,
73+
std::int64_t ldb,
74+
std::vector<sycl::event> &host_task_events,
75+
const std::vector<sycl::event> &depends)
76+
{
77+
type_utils::validate_type_for_device<T>(exec_q);
78+
79+
T *a = reinterpret_cast<T *>(in_a);
80+
T *b = reinterpret_cast<T *>(in_b);
81+
82+
const std::int64_t scratchpad_size =
83+
mkl_lapack::getrs_scratchpad_size<T>(exec_q, trans, n, nrhs, lda, ldb);
84+
T *scratchpad = nullptr;
85+
86+
std::stringstream error_msg;
87+
std::int64_t info = 0;
88+
bool is_exception_caught = false;
89+
90+
sycl::event getrs_event;
91+
try {
92+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
93+
94+
getrs_event = mkl_lapack::getrs(
95+
exec_q,
96+
trans, // Specifies the operation: whether or not to transpose
97+
// matrix A. Can be 'N' for no transpose, 'T' for transpose,
98+
// and 'C' for conjugate transpose.
99+
n, // The order of the square matrix A
100+
// and the number of rows in matrix B (0 ≤ n).
101+
// It must be a non-negative integer.
102+
nrhs, // The number of right-hand sides,
103+
// i.e., the number of columns in matrix B (0 ≤ nrhs).
104+
a, // Pointer to the square matrix A (n x n).
105+
lda, // The leading dimension of matrix A, must be at least max(1,
106+
// n). It must be at least max(1, n).
107+
ipiv, // Pointer to the output array of pivot indices that were used
108+
// during factorization (n, ).
109+
b, // Pointer to the matrix B of right-hand sides (ldb, nrhs).
110+
ldb, // The leading dimension of matrix B, must be at least max(1,
111+
// n).
112+
scratchpad, // Pointer to scratchpad memory to be used by MKL
113+
// routine for storing intermediate results.
114+
scratchpad_size, depends);
115+
} catch (mkl_lapack::exception const &e) {
116+
is_exception_caught = true;
117+
info = e.info();
118+
119+
if (info < 0) {
120+
error_msg << "Parameter number " << -info
121+
<< " had an illegal value.";
122+
}
123+
else if (info == scratchpad_size && e.detail() != 0) {
124+
error_msg
125+
<< "Insufficient scratchpad size. Required size is at least "
126+
<< e.detail();
127+
}
128+
else if (info > 0) {
129+
is_exception_caught = false;
130+
if (scratchpad != nullptr) {
131+
sycl::free(scratchpad, exec_q);
132+
}
133+
throw LinAlgError("The solve could not be completed.");
134+
}
135+
else {
136+
error_msg << "Unexpected MKL exception caught during getrs() "
137+
"call:\nreason: "
138+
<< e.what() << "\ninfo: " << e.info();
139+
}
140+
} catch (sycl::exception const &e) {
141+
is_exception_caught = true;
142+
error_msg << "Unexpected SYCL exception caught during getrs() call:\n"
143+
<< e.what();
144+
}
145+
146+
if (is_exception_caught) // an unexpected error occurs
147+
{
148+
if (scratchpad != nullptr) {
149+
sycl::free(scratchpad, exec_q);
150+
}
151+
152+
throw std::runtime_error(error_msg.str());
153+
}
154+
155+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
156+
cgh.depends_on(getrs_event);
157+
auto ctx = exec_q.get_context();
158+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
159+
});
160+
host_task_events.push_back(clean_up_event);
161+
return getrs_event;
162+
}
163+
164+
std::pair<sycl::event, sycl::event>
165+
getrs(sycl::queue exec_q,
166+
dpctl::tensor::usm_ndarray a_array,
167+
dpctl::tensor::usm_ndarray ipiv_array,
168+
dpctl::tensor::usm_ndarray b_array,
169+
const std::vector<sycl::event> &depends)
170+
{
171+
const int a_array_nd = a_array.get_ndim();
172+
const int b_array_nd = b_array.get_ndim();
173+
const int ipiv_array_nd = ipiv_array.get_ndim();
174+
175+
if (a_array_nd != 2) {
176+
throw py::value_error(
177+
"The LU-factorized array has ndim=" + std::to_string(a_array_nd) +
178+
", but a 2-dimensional array is expected.");
179+
}
180+
if (b_array_nd > 2) {
181+
throw py::value_error(
182+
"The right-hand sides array has ndim=" +
183+
std::to_string(b_array_nd) +
184+
", but a 1-dimensional or a 2-dimensional array is expected.");
185+
}
186+
if (ipiv_array_nd != 1) {
187+
throw py::value_error("The array of pivot indices has ndim=" +
188+
std::to_string(ipiv_array_nd) +
189+
", but a 1-dimensional array is expected.");
190+
}
191+
192+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
193+
const py::ssize_t *b_array_shape = b_array.get_shape_raw();
194+
195+
if (a_array_shape[0] != a_array_shape[1]) {
196+
throw py::value_error("The LU-factorized array must be square,"
197+
" but got a shape of (" +
198+
std::to_string(a_array_shape[0]) + ", " +
199+
std::to_string(a_array_shape[1]) + ").");
200+
}
201+
202+
// check compatibility of execution queue and allocation queue
203+
if (!dpctl::utils::queues_are_compatible(exec_q,
204+
{a_array, b_array, ipiv_array}))
205+
{
206+
throw py::value_error(
207+
"Execution queue is not compatible with allocation queues");
208+
}
209+
210+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
211+
if (overlap(a_array, b_array)) {
212+
throw py::value_error("The LU-factorized and right-hand sides arrays "
213+
"are overlapping segments of memory");
214+
}
215+
216+
bool is_a_array_c_contig = a_array.is_c_contiguous();
217+
bool is_a_array_f_contig = a_array.is_f_contiguous();
218+
bool is_b_array_f_contig = b_array.is_f_contiguous();
219+
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
220+
bool is_ipiv_array_f_contig = ipiv_array.is_f_contiguous();
221+
if (!is_a_array_c_contig && !is_a_array_f_contig) {
222+
throw py::value_error("The LU-factorized array "
223+
"must be either C-contiguous "
224+
"or F-contiguous");
225+
}
226+
if (!is_b_array_f_contig) {
227+
throw py::value_error("The right-hand sides array "
228+
"must be F-contiguous");
229+
}
230+
if (!is_ipiv_array_c_contig || !is_ipiv_array_f_contig) {
231+
throw py::value_error("The array of pivot indices "
232+
"must be contiguous");
233+
}
234+
235+
auto array_types = dpctl_td_ns::usm_ndarray_types();
236+
int a_array_type_id =
237+
array_types.typenum_to_lookup_id(a_array.get_typenum());
238+
int b_array_type_id =
239+
array_types.typenum_to_lookup_id(b_array.get_typenum());
240+
241+
if (a_array_type_id != b_array_type_id) {
242+
throw py::value_error("The types of the LU-factorized and "
243+
"right-hand sides arrays are mismatched");
244+
}
245+
246+
getrs_impl_fn_ptr_t getrs_fn = getrs_dispatch_vector[a_array_type_id];
247+
if (getrs_fn == nullptr) {
248+
throw py::value_error(
249+
"No getrs implementation defined for the provided type "
250+
"of the input matrix.");
251+
}
252+
253+
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
254+
int ipiv_array_type_id =
255+
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
256+
257+
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
258+
throw py::value_error("The type of 'ipiv_array' must be int64.");
259+
}
260+
261+
const std::int64_t n = a_array_shape[0];
262+
const std::int64_t nrhs = (b_array_nd > 1) ? b_array_shape[1] : 1;
263+
264+
const std::int64_t lda = std::max<size_t>(1UL, n);
265+
const std::int64_t ldb = std::max<size_t>(1UL, n);
266+
267+
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
268+
// For F-contiguous we use transpose::N.
269+
oneapi::mkl::transpose trans = is_a_array_c_contig
270+
? oneapi::mkl::transpose::T
271+
: oneapi::mkl::transpose::N;
272+
273+
char *a_array_data = a_array.get_data();
274+
char *b_array_data = b_array.get_data();
275+
char *ipiv_array_data = ipiv_array.get_data();
276+
277+
std::int64_t *ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
278+
279+
std::vector<sycl::event> host_task_events;
280+
sycl::event getrs_ev =
281+
getrs_fn(exec_q, trans, n, nrhs, a_array_data, lda, ipiv, b_array_data,
282+
ldb, host_task_events, depends);
283+
284+
sycl::event args_ev = dpctl::utils::keep_args_alive(
285+
exec_q, {a_array, b_array, ipiv_array}, host_task_events);
286+
287+
return std::make_pair(args_ev, getrs_ev);
288+
}
289+
290+
template <typename fnT, typename T>
291+
struct GetrsContigFactory
292+
{
293+
fnT get()
294+
{
295+
if constexpr (types::GetrsTypePairSupportFactory<T>::is_defined) {
296+
return getrs_impl<T>;
297+
}
298+
else {
299+
return nullptr;
300+
}
301+
}
302+
};
303+
304+
void init_getrs_dispatch_vector(void)
305+
{
306+
dpctl_td_ns::DispatchVectorBuilder<getrs_impl_fn_ptr_t, GetrsContigFactory,
307+
dpctl_td_ns::num_types>
308+
contig;
309+
contig.populate_dispatch_vector(getrs_dispatch_vector);
310+
}
311+
} // namespace lapack
312+
} // namespace ext
313+
} // namespace backend
314+
} // namespace dpnp

0 commit comments

Comments
 (0)