Skip to content

Commit 1e86753

Browse files
Update dpnp.linalg.qr() function (#1673)
* Impl dpnp.linalg.qr for 2d array * Add cupy tests for dpnp.linalg.qr * Add batch implementation of dpnp.linalg.qr * Remove an old impl of dpnp_qr * Update test_qr in test_sycl_queue * Add test_qr in test_usm_type * Use _real_type for _orgqr * Use _real_type for _orgqr_batch * Update dpnp tests for dpnp.linalg.qr * Pass scratchpad_size to the error message test * Add additional checks * Extend error handler for mkl batch funcs * Add ungqr mkl extension to support complex dtype * Update tau array size check for orgqr * Add ungqr_batch mkl extension to support complex dtype * Add arrays type check * Fix test_det_singular_matrix * Expand tests for dpnp.linalg.qr with complex types * Update examples * Remove astype for output arrays * Use empty_like instead of empty * Use ht_list_ev with dpctl.SyclEvent.wait_for * Add _triu_inplace func * Use copy_usm for a_t array overwritten by geqrf/geqrf_batch --------- Co-authored-by: Anton <[email protected]>
1 parent d45bb24 commit 1e86753

22 files changed

+2767
-246
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,21 @@
2727
set(python_module_name _lapack_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/geqrf.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/geqrf_batch.cpp
3032
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
3133
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
3234
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
3335
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3436
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
3537
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
38+
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp
3640
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
3741
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
3842
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
43+
${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp
44+
${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp
3945
)
4046

4147
pybind11_add_module(${python_module_name} MODULE ${_module_src})
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "geqrf.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace lapack
44+
{
45+
namespace mkl_lapack = oneapi::mkl::lapack;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*geqrf_impl_fn_ptr_t)(sycl::queue,
50+
const std::int64_t,
51+
const std::int64_t,
52+
char *,
53+
std::int64_t,
54+
char *,
55+
std::vector<sycl::event> &,
56+
const std::vector<sycl::event> &);
57+
58+
static geqrf_impl_fn_ptr_t geqrf_dispatch_vector[dpctl_td_ns::num_types];
59+
60+
template <typename T>
61+
static sycl::event geqrf_impl(sycl::queue exec_q,
62+
const std::int64_t m,
63+
const std::int64_t n,
64+
char *in_a,
65+
std::int64_t lda,
66+
char *in_tau,
67+
std::vector<sycl::event> &host_task_events,
68+
const std::vector<sycl::event> &depends)
69+
{
70+
type_utils::validate_type_for_device<T>(exec_q);
71+
72+
T *a = reinterpret_cast<T *>(in_a);
73+
T *tau = reinterpret_cast<T *>(in_tau);
74+
75+
const std::int64_t scratchpad_size =
76+
mkl_lapack::geqrf_scratchpad_size<T>(exec_q, m, n, lda);
77+
T *scratchpad = nullptr;
78+
79+
std::stringstream error_msg;
80+
std::int64_t info = 0;
81+
bool is_exception_caught = false;
82+
83+
sycl::event geqrf_event;
84+
try {
85+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
86+
87+
geqrf_event = mkl_lapack::geqrf(
88+
exec_q,
89+
m, // The number of rows in the matrix; (0 ≤ m).
90+
n, // The number of columns in the matrix; (0 ≤ n).
91+
a, // Pointer to the m-by-n matrix.
92+
lda, // The leading dimension of `a`; (1 ≤ m).
93+
tau, // Pointer to the array of scalar factors of the
94+
// elementary reflectors.
95+
scratchpad, // Pointer to scratchpad memory to be used by MKL
96+
// routine for storing intermediate results.
97+
scratchpad_size, depends);
98+
} catch (mkl_lapack::exception const &e) {
99+
is_exception_caught = true;
100+
info = e.info();
101+
102+
if (info < 0) {
103+
error_msg << "Parameter number " << -info
104+
<< " had an illegal value.";
105+
}
106+
else if (info == scratchpad_size && e.detail() != 0) {
107+
error_msg
108+
<< "Insufficient scratchpad size. Required size is at least "
109+
<< e.detail() << ", but current size is " << scratchpad_size
110+
<< ".";
111+
}
112+
else {
113+
error_msg << "Unexpected MKL exception caught during geqrf() "
114+
"call:\nreason: "
115+
<< e.what() << "\ninfo: " << info;
116+
}
117+
} catch (sycl::exception const &e) {
118+
is_exception_caught = true;
119+
error_msg << "Unexpected SYCL exception caught during geqrf() call:\n"
120+
<< e.what();
121+
}
122+
123+
if (is_exception_caught) // an unexpected error occurs
124+
{
125+
if (scratchpad != nullptr) {
126+
sycl::free(scratchpad, exec_q);
127+
}
128+
throw std::runtime_error(error_msg.str());
129+
}
130+
131+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
132+
cgh.depends_on(geqrf_event);
133+
auto ctx = exec_q.get_context();
134+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
135+
});
136+
host_task_events.push_back(clean_up_event);
137+
138+
return geqrf_event;
139+
}
140+
141+
std::pair<sycl::event, sycl::event>
142+
geqrf(sycl::queue q,
143+
dpctl::tensor::usm_ndarray a_array,
144+
dpctl::tensor::usm_ndarray tau_array,
145+
const std::vector<sycl::event> &depends)
146+
{
147+
const int a_array_nd = a_array.get_ndim();
148+
const int tau_array_nd = tau_array.get_ndim();
149+
150+
if (a_array_nd != 2) {
151+
throw py::value_error(
152+
"The input array has ndim=" + std::to_string(a_array_nd) +
153+
", but a 2-dimensional array is expected.");
154+
}
155+
156+
if (tau_array_nd != 1) {
157+
throw py::value_error("The array of Householder scalars has ndim=" +
158+
std::to_string(tau_array_nd) +
159+
", but a 1-dimensional array is expected.");
160+
}
161+
162+
// check compatibility of execution queue and allocation queue
163+
if (!dpctl::utils::queues_are_compatible(q, {a_array, tau_array})) {
164+
throw py::value_error(
165+
"Execution queue is not compatible with allocation queues");
166+
}
167+
168+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
169+
if (overlap(a_array, tau_array)) {
170+
throw py::value_error(
171+
"The input array and the array of Householder scalars "
172+
"are overlapping segments of memory");
173+
}
174+
175+
bool is_a_array_c_contig = a_array.is_c_contiguous();
176+
if (!is_a_array_c_contig) {
177+
throw py::value_error("The input array "
178+
"must be C-contiguous");
179+
}
180+
181+
bool is_tau_array_c_contig = tau_array.is_c_contiguous();
182+
bool is_tau_array_f_contig = tau_array.is_f_contiguous();
183+
184+
if (!is_tau_array_c_contig || !is_tau_array_f_contig) {
185+
throw py::value_error("The array of Householder scalars "
186+
"must be contiguous");
187+
}
188+
189+
auto array_types = dpctl_td_ns::usm_ndarray_types();
190+
int a_array_type_id =
191+
array_types.typenum_to_lookup_id(a_array.get_typenum());
192+
int tau_array_type_id =
193+
array_types.typenum_to_lookup_id(tau_array.get_typenum());
194+
195+
if (a_array_type_id != tau_array_type_id) {
196+
throw py::value_error(
197+
"The types of the input array and "
198+
"the array of Householder scalars are mismatched");
199+
}
200+
201+
geqrf_impl_fn_ptr_t geqrf_fn = geqrf_dispatch_vector[a_array_type_id];
202+
if (geqrf_fn == nullptr) {
203+
throw py::value_error(
204+
"No geqrf implementation defined for the provided type "
205+
"of the input matrix.");
206+
}
207+
208+
char *a_array_data = a_array.get_data();
209+
char *tau_array_data = tau_array.get_data();
210+
211+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
212+
213+
// The input array is transponded
214+
// Change the order of getting m, n
215+
const std::int64_t m = a_array_shape[1];
216+
const std::int64_t n = a_array_shape[0];
217+
const std::int64_t lda = std::max<size_t>(1UL, m);
218+
219+
const size_t tau_array_size = tau_array.get_size();
220+
const size_t min_m_n = std::max<size_t>(1UL, std::min<size_t>(m, n));
221+
222+
if (tau_array_size != min_m_n) {
223+
throw py::value_error("The array of Householder scalars has size=" +
224+
std::to_string(tau_array_size) + ", but a size=" +
225+
std::to_string(min_m_n) + " array is expected.");
226+
}
227+
228+
std::vector<sycl::event> host_task_events;
229+
sycl::event geqrf_ev = geqrf_fn(q, m, n, a_array_data, lda, tau_array_data,
230+
host_task_events, depends);
231+
232+
sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array, tau_array},
233+
host_task_events);
234+
235+
return std::make_pair(args_ev, geqrf_ev);
236+
}
237+
238+
template <typename fnT, typename T>
239+
struct GeqrfContigFactory
240+
{
241+
fnT get()
242+
{
243+
if constexpr (types::GeqrfTypePairSupportFactory<T>::is_defined) {
244+
return geqrf_impl<T>;
245+
}
246+
else {
247+
return nullptr;
248+
}
249+
}
250+
};
251+
252+
void init_geqrf_dispatch_vector(void)
253+
{
254+
dpctl_td_ns::DispatchVectorBuilder<geqrf_impl_fn_ptr_t, GeqrfContigFactory,
255+
dpctl_td_ns::num_types>
256+
contig;
257+
contig.populate_dispatch_vector(geqrf_dispatch_vector);
258+
}
259+
} // namespace lapack
260+
} // namespace ext
261+
} // namespace backend
262+
} // namespace dpnp
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace lapack
40+
{
41+
extern std::pair<sycl::event, sycl::event>
42+
geqrf(sycl::queue exec_q,
43+
dpctl::tensor::usm_ndarray a_array,
44+
dpctl::tensor::usm_ndarray tau_array,
45+
const std::vector<sycl::event> &depends = {});
46+
47+
extern std::pair<sycl::event, sycl::event>
48+
geqrf_batch(sycl::queue exec_q,
49+
dpctl::tensor::usm_ndarray a_array,
50+
dpctl::tensor::usm_ndarray tau_array,
51+
std::int64_t m,
52+
std::int64_t n,
53+
std::int64_t stride_a,
54+
std::int64_t stride_tau,
55+
std::int64_t batch_size,
56+
const std::vector<sycl::event> &depends = {});
57+
58+
extern void init_geqrf_batch_dispatch_vector(void);
59+
extern void init_geqrf_dispatch_vector(void);
60+
} // namespace lapack
61+
} // namespace ext
62+
} // namespace backend
63+
} // namespace dpnp

0 commit comments

Comments
 (0)