Update dpnp.linalg.svd() function #1604

Merged
merged 78 commits into from
Feb 2, 2024
78 commits
d464ba3
Draft commit of dpnp.linalg.svd impl
vlad-perevezentsev Oct 18, 2023
48d9b61
Pass empty arrays if compute_uv=False
vlad-perevezentsev Oct 18, 2023
f6b4c1f
Remove unnecessary comments
vlad-perevezentsev Oct 18, 2023
9a4d62e
Add logic for the input array n < m
vlad-perevezentsev Oct 19, 2023
a942dcf
Small changes
vlad-perevezentsev Oct 19, 2023
54febf6
Add a new cupy test_decomposition
vlad-perevezentsev Oct 19, 2023
b292afb
Merge master into impl_svd
vlad-perevezentsev Oct 19, 2023
4da97be
Rename gesvd input parameters
vlad-perevezentsev Oct 20, 2023
303f23a
Correspondence of passed parameters to gesvd signature
vlad-perevezentsev Oct 20, 2023
80d5e40
Correct initialization of result variables in dpnp_svd
vlad-perevezentsev Oct 20, 2023
e386f3d
Update test_decomposition
vlad-perevezentsev Oct 20, 2023
da6d0c9
Add implementation of _dpnp_svd_batch
vlad-perevezentsev Oct 20, 2023
d59ac0c
Add test_decomposition to the scope of public CI
vlad-perevezentsev Oct 20, 2023
02270f8
Use mkl_lapack
vlad-perevezentsev Oct 24, 2023
8d01e4f
Improve error handling for mkl_lapack::gesvd function
vlad-perevezentsev Oct 24, 2023
6490d8d
Declate detail variable
vlad-perevezentsev Oct 24, 2023
ca427c6
Use a_usm_type and a_sycl_queue variables
vlad-perevezentsev Oct 24, 2023
327e8df
Check a.ndim < 2
vlad-perevezentsev Oct 24, 2023
0a99b84
Add additional checks for gesvd function
vlad-perevezentsev Oct 24, 2023
d03405e
Merge master into impl_svd
vlad-perevezentsev Oct 24, 2023
25bfbb5
Remove old dpnp_svd backend
vlad-perevezentsev Oct 24, 2023
f123ece
Refresh test_svd in test_linalg
vlad-perevezentsev Oct 24, 2023
a96fc7e
Merge master into impl_svd
vlad-perevezentsev Nov 2, 2023
7cf7d31
Add detailed comments for gesvd arguments
vlad-perevezentsev Nov 2, 2023
3146428
gesvd returns pair of events and uses dpctl.utils.keep_args_alive
vlad-perevezentsev Nov 2, 2023
5997a3a
Keep a lexicographical order
vlad-perevezentsev Nov 2, 2023
d1ce945
Update docstrings for svd
vlad-perevezentsev Nov 2, 2023
d4e5389
Merge master into impl_svd
vlad-perevezentsev Nov 3, 2023
468c9c8
Merge master into impl_svd
vlad-perevezentsev Nov 22, 2023
1a6f4c0
Add test_svd to test_usm_type
vlad-perevezentsev Nov 22, 2023
c08959d
Merge master into impl_svd
vlad-perevezentsev Dec 18, 2023
e43b9a3
Remove all TODOs
vlad-perevezentsev Dec 18, 2023
96f243c
Add a new impl to get s_type
vlad-perevezentsev Dec 18, 2023
87825f5
Address remarks
vlad-perevezentsev Dec 18, 2023
9765e35
Add a description for _stacked_identity
vlad-perevezentsev Dec 19, 2023
61257e2
Move internal funcs up
vlad-perevezentsev Dec 19, 2023
61d1ed7
Simplify dpnp_svd_batch
vlad-perevezentsev Dec 19, 2023
33c4f5e
Update tests for dpnp.linalg.svd
vlad-perevezentsev Dec 19, 2023
0be3132
Add hermitian argument support
vlad-perevezentsev Dec 20, 2023
6b63eea
Add test_svd_hermitian
vlad-perevezentsev Dec 20, 2023
12a5cb5
Update svd docstrings
vlad-perevezentsev Dec 20, 2023
3845613
Tune tolerance
vlad-perevezentsev Dec 20, 2023
d12d49d
Merge master into impl_svd
vlad-perevezentsev Dec 20, 2023
c0e0462
Update test_svd_errors
vlad-perevezentsev Dec 20, 2023
ddb4d43
Update _common_type and _common_inexact_type
vlad-perevezentsev Dec 22, 2023
8f4b557
Address the remarks
vlad-perevezentsev Dec 22, 2023
6ba9483
Remove passing n and m parameteres to _gesvd
vlad-perevezentsev Jan 10, 2024
6008ce4
Simplify results return logic for dpnp_svd_batch
vlad-perevezentsev Jan 10, 2024
16369a0
Update condition and random files in cupy/testing to use fix_random a…
vlad-perevezentsev Jan 10, 2024
937d2ac
Merge master into impl_svd
vlad-perevezentsev Jan 10, 2024
9235112
Rename cupy/testing/condition.py to .../_condition.py
vlad-perevezentsev Jan 10, 2024
8815da2
Use self._tol in TestSvd
vlad-perevezentsev Jan 11, 2024
f28e9ae
Add TODO for check_decomposition
vlad-perevezentsev Jan 11, 2024
2118d3e
Update gesvd error handler
vlad-perevezentsev Jan 17, 2024
c89817a
Merge master into impl_svd
vlad-perevezentsev Jan 17, 2024
1fcfb5e
Merge master into impl_svd
vlad-perevezentsev Jan 19, 2024
347fbe9
dpnp_svd works with F contiguous arrays
vlad-perevezentsev Jan 22, 2024
03d04d7
Merge master into impl_svd
vlad-perevezentsev Jan 22, 2024
dd74a36
Add additional checks for output arrays
vlad-perevezentsev Jan 22, 2024
5a4721f
Tune atol for TestSvd
vlad-perevezentsev Jan 22, 2024
31e8bbb
Impl parallel calculation in dpnp_svd_batch
vlad-perevezentsev Jan 22, 2024
876842a
Skip using @_condition.repeat in cupy tests
vlad-perevezentsev Jan 22, 2024
9d5afde
Resolve conflicts
vlad-perevezentsev Jan 22, 2024
1f66d0a
Merge master into impl_svd
vlad-perevezentsev Jan 22, 2024
63f257f
Add additional checks for output arrays
vlad-perevezentsev Jan 23, 2024
8c2a8ff
Update docstrings for svd
vlad-perevezentsev Jan 23, 2024
7a27d4c
Use dpctl.SyclEvent.wait_for in dpnp_svd_batch
vlad-perevezentsev Jan 23, 2024
ac4283a
Address remarks
vlad-perevezentsev Jan 24, 2024
f64e6b2
Merge master into impl_svd
vlad-perevezentsev Jan 24, 2024
647befa
Add TODO : matching the order of returned arrays
vlad-perevezentsev Jan 24, 2024
eeb272b
Merge master into impl_svd
vlad-perevezentsev Jan 26, 2024
405530c
Merge master into impl_svd
vlad-perevezentsev Feb 1, 2024
036e1df
Skip cupy tests on windows
vlad-perevezentsev Feb 1, 2024
0fef6ce
Rename condition to _condition
vlad-perevezentsev Feb 1, 2024
7d7a01c
Tune tol in TestSvd
vlad-perevezentsev Feb 1, 2024
3c1c3be
Merge master into impl_svd
vlad-perevezentsev Feb 1, 2024
6344767
Set setUpClass to skip cupy tests on cpu
vlad-perevezentsev Feb 2, 2024
cee13dc
Merge master into impl_svd
vlad-perevezentsev Feb 2, 2024
1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
@@ -28,6 +28,7 @@ set(python_module_name _lapack_impl)
 set(_module_src
     ${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
359 changes: 359 additions & 0 deletions dpnp/backend/extensions/lapack/gesvd.cpp
@@ -0,0 +1,359 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "gesvd.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue,
                                           const oneapi::mkl::jobsvd,
                                           const oneapi::mkl::jobsvd,
                                           const std::int64_t,
                                           const std::int64_t,
                                           char *,
                                           const std::int64_t,
                                           char *,
                                           char *,
                                           const std::int64_t,
                                           char *,
                                           const std::int64_t,
                                           std::vector<sycl::event> &,
                                           const std::vector<sycl::event> &);

static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types]
                                               [dpctl_td_ns::num_types];

// Converts a given character code (ord) to the corresponding
// oneapi::mkl::jobsvd enumeration value
static oneapi::mkl::jobsvd process_job(std::int8_t job_val)
{
    switch (job_val) {
    case 'A':
        return oneapi::mkl::jobsvd::vectors;
    case 'S':
        return oneapi::mkl::jobsvd::somevec;
    case 'O':
        return oneapi::mkl::jobsvd::vectorsina;
    case 'N':
        return oneapi::mkl::jobsvd::novec;
    default:
        throw std::invalid_argument("Unknown value for job");
    }
}

template <typename T, typename RealT>
static sycl::event gesvd_impl(sycl::queue exec_q,
                              const oneapi::mkl::jobsvd jobu,
                              const oneapi::mkl::jobsvd jobvt,
                              const std::int64_t m,
                              const std::int64_t n,
                              char *in_a,
                              const std::int64_t lda,
                              char *out_s,
                              char *out_u,
                              const std::int64_t ldu,
                              char *out_vt,
                              const std::int64_t ldvt,
                              std::vector<sycl::event> &host_task_events,
                              const std::vector<sycl::event> &depends)
{
    type_utils::validate_type_for_device<T>(exec_q);
    type_utils::validate_type_for_device<RealT>(exec_q);

    T *a = reinterpret_cast<T *>(in_a);
    RealT *s = reinterpret_cast<RealT *>(out_s);
    T *u = reinterpret_cast<T *>(out_u);
    T *vt = reinterpret_cast<T *>(out_vt);

    const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>(
        exec_q, jobu, jobvt, m, n, lda, ldu, ldvt);
    T *scratchpad = nullptr;

    std::stringstream error_msg;
    std::int64_t info = 0;
    bool is_exception_caught = false;

    sycl::event gesvd_event;
    try {
        scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);

        gesvd_event = mkl_lapack::gesvd(
            exec_q,
            jobu,  // Character specifying how to compute the matrix U:
                   // 'A' computes all columns of U,
                   // 'S' computes the first min(m,n) columns of U,
                   // 'O' overwrites A with the columns of U,
                   // 'N' does not compute U.
            jobvt, // Character specifying how to compute the matrix VT:
                   // 'A' computes all rows of VT,
                   // 'S' computes the first min(m,n) rows of VT,
                   // 'O' overwrites A with the rows of VT,
                   // 'N' does not compute VT.
            m,     // The number of rows in the input matrix A (0 <= m).
            n,     // The number of columns in the input matrix A (0 <= n).
            a,     // Pointer to the input matrix A of size (m x n).
            lda,   // The leading dimension of A, must be at least max(1, m).
            s,     // Pointer to the array containing the singular values.
            u,     // Pointer to the matrix U in the singular value decomposition.
            ldu,   // The leading dimension of U, must be at least max(1, m).
            vt,    // Pointer to the matrix VT in the singular value decomposition.
            ldvt,  // The leading dimension of VT, must be at least max(1, n).
            scratchpad, // Pointer to scratchpad memory to be used by MKL
                        // routine for storing intermediate results.
            scratchpad_size, depends);
    } catch (mkl_lapack::exception const &e) {
        is_exception_caught = true;
        info = e.info();
        if (info < 0) {
            error_msg << "Parameter number " << -info
                      << " had an illegal value.";
        }
        else if (info == scratchpad_size && e.detail() != 0) {
            error_msg
                << "Insufficient scratchpad size. Required size is at least "
                << e.detail();
        }
        else if (info > 0) {
            error_msg << "The algorithm computing SVD failed to converge; "
                      << info << " off-diagonal elements of an intermediate "
                      << "bidiagonal form did not converge to zero.\n";
        }
        else {
            error_msg << "Unexpected MKL exception caught during gesvd() "
                         "call:\nreason: "
                      << e.what() << "\ninfo: " << e.info();
        }
    } catch (sycl::exception const &e) {
        is_exception_caught = true;
        error_msg << "Unexpected SYCL exception caught during gesvd() call:\n"
                  << e.what();
    }

    if (is_exception_caught) // an unexpected error occurs
    {
        if (scratchpad != nullptr) {
            sycl::free(scratchpad, exec_q);
        }
        throw std::runtime_error(error_msg.str());
    }

    sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(gesvd_event);
        auto ctx = exec_q.get_context();
        cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
    });
    host_task_events.push_back(clean_up_event);
    return gesvd_event;
}

std::pair<sycl::event, sycl::event>
    gesvd(sycl::queue exec_q,
          const std::int8_t jobu_val,
          const std::int8_t jobvt_val,
          dpctl::tensor::usm_ndarray a_array,
          dpctl::tensor::usm_ndarray out_s,
          dpctl::tensor::usm_ndarray out_u,
          dpctl::tensor::usm_ndarray out_vt,
          const std::vector<sycl::event> &depends)
{
    const int a_array_nd = a_array.get_ndim();
    const int out_u_array_nd = out_u.get_ndim();
    const int out_s_array_nd = out_s.get_ndim();
    const int out_vt_array_nd = out_vt.get_ndim();

    if (a_array_nd != 2) {
        throw py::value_error(
            "The input array has ndim=" + std::to_string(a_array_nd) +
            ", but a 2-dimensional array is expected.");
    }

    if (out_s_array_nd != 1) {
        throw py::value_error("The output array of singular values has ndim=" +
                              std::to_string(out_s_array_nd) +
                              ", but a 1-dimensional array is expected.");
    }

    if (jobu_val == 'N' && jobvt_val == 'N') {
        if (out_u_array_nd != 0) {
            throw py::value_error(
                "The output array of the left singular vectors has ndim=" +
                std::to_string(out_u_array_nd) +
                ", but it is not used and should have ndim=0.");
        }
        if (out_vt_array_nd != 0) {
            throw py::value_error(
                "The output array of the right singular vectors has ndim=" +
                std::to_string(out_vt_array_nd) +
                ", but it is not used and should have ndim=0.");
        }
    }
    else {
        if (out_u_array_nd != 2) {
            throw py::value_error(
                "The output array of the left singular vectors has ndim=" +
                std::to_string(out_u_array_nd) +
                ", but a 2-dimensional array is expected.");
        }
        if (out_vt_array_nd != 2) {
            throw py::value_error(
                "The output array of the right singular vectors has ndim=" +
                std::to_string(out_vt_array_nd) +
                ", but a 2-dimensional array is expected.");
        }
    }

    // check compatibility of execution queue and allocation queue
    if (!dpctl::utils::queues_are_compatible(
            exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(),
                     out_vt.get_queue()}))
    {
        throw std::runtime_error(
            "USM allocations are not compatible with the execution queue.");
    }

    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
        overlap(a_array, out_vt) || overlap(out_s, out_u) ||
        overlap(out_s, out_vt) || overlap(out_u, out_vt))
    {
        throw py::value_error("Arrays have overlapping segments of memory");
    }

    bool is_a_array_f_contig = a_array.is_f_contiguous();
    if (!is_a_array_f_contig) {
        throw py::value_error("The input array must be F-contiguous");
    }

    bool is_out_u_array_f_contig = out_u.is_f_contiguous();
    bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();

    if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
        throw py::value_error("The output arrays of the left and right "
                              "singular vectors must be F-contiguous");
    }

    bool is_out_s_array_c_contig = out_s.is_c_contiguous();
    bool is_out_s_array_f_contig = out_s.is_f_contiguous();

    if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) {
        throw py::value_error("The output array of singular values "
                              "must be contiguous");
    }

    auto array_types = dpctl_td_ns::usm_ndarray_types();
    int a_array_type_id =
        array_types.typenum_to_lookup_id(a_array.get_typenum());
    int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
    int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum());
    int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());

    if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
        throw py::type_error(
            "Input array, output left singular vectors array, "
            "and output right singular vectors array must have "
            "the same data type");
    }

    gesvd_impl_fn_ptr_t gesvd_fn =
        gesvd_dispatch_table[a_array_type_id][out_s_type_id];
    if (gesvd_fn == nullptr) {
        throw py::value_error(
            "No gesvd implementation is defined for the given pair "
            "of array type and output singular values type.");
    }

    char *a_array_data = a_array.get_data();
    char *out_s_data = out_s.get_data();
    char *out_u_data = out_u.get_data();
    char *out_vt_data = out_vt.get_data();

    const py::ssize_t *a_array_shape = a_array.get_shape_raw();
    const std::int64_t m = a_array_shape[0];
    const std::int64_t n = a_array_shape[1];

    const std::int64_t lda = std::max<size_t>(1UL, m);
    const std::int64_t ldu = std::max<size_t>(1UL, m);
    const std::int64_t ldvt =
        std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n);

    const oneapi::mkl::jobsvd jobu = process_job(jobu_val);
    const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val);

    std::vector<sycl::event> host_task_events;
    sycl::event gesvd_ev =
        gesvd_fn(exec_q, jobu, jobvt, m, n, a_array_data, lda, out_s_data,
                 out_u_data, ldu, out_vt_data, ldvt, host_task_events, depends);

    sycl::event args_ev = dpctl::utils::keep_args_alive(
        exec_q, {a_array, out_s, out_u, out_vt}, host_task_events);

    return std::make_pair(args_ev, gesvd_ev);
}

template <typename fnT, typename T, typename RealT>
struct GesvdContigFactory
{
    fnT get()
    {
        if constexpr (types::GesvdTypePairSupportFactory<T, RealT>::is_defined)
        {
            return gesvd_impl<T, RealT>;
        }
        else {
            return nullptr;
        }
    }
};

void init_gesvd_dispatch_table(void)
{
    dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t, GesvdContigFactory,
                                      dpctl_td_ns::num_types>
        contig;
    contig.populate_dispatch_table(gesvd_dispatch_table);
}
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
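
The job-code mapping, the leading-dimension arithmetic, and the `info` decoding in `gesvd.cpp` can be tried out without SYCL or oneMKL. The following is a minimal sketch under stated assumptions, not code from this PR: `JobSvd`, `LeadingDims`, `leading_dims`, and `describe_info` are hypothetical names introduced for illustration, and `JobSvd` is only a stand-in for `oneapi::mkl::jobsvd`. It assumes column-major (F-contiguous) storage, which is what the checks in `gesvd()` require.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for oneapi::mkl::jobsvd so the sketch builds
// without oneMKL installed.
enum class JobSvd { vectors, somevec, vectorsina, novec };

// Same character-to-enum mapping as process_job() in gesvd.cpp.
inline JobSvd process_job(char job_val)
{
    switch (job_val) {
    case 'A': return JobSvd::vectors;    // all singular vectors
    case 'S': return JobSvd::somevec;    // first min(m, n) vectors
    case 'O': return JobSvd::vectorsina; // vectors overwrite A
    case 'N': return JobSvd::novec;      // no vectors
    default:
        throw std::invalid_argument("Unknown value for job");
    }
}

// Hypothetical container for the three leading dimensions.
struct LeadingDims
{
    std::int64_t lda;
    std::int64_t ldu;
    std::int64_t ldvt;
};

// Mirrors the arithmetic in gesvd(): A and U have m rows; VT has
// min(m, n) rows when jobvt is 'S' and n rows otherwise. Each leading
// dimension is clamped to at least 1 so empty inputs stay valid.
inline LeadingDims leading_dims(std::int64_t m, std::int64_t n, char jobvt)
{
    const std::int64_t one = 1;
    const std::int64_t vt_rows = (jobvt == 'S') ? std::min(m, n) : n;
    return {std::max(one, m), std::max(one, m), std::max(one, vt_rows)};
}

// Hypothetical helper that follows the branch order of the catch block
// in gesvd_impl(): negative info first, then the scratchpad case, then
// the convergence failure.
inline std::string describe_info(std::int64_t info,
                                 std::int64_t scratchpad_size,
                                 std::int64_t detail)
{
    if (info < 0) {
        return "Parameter number " + std::to_string(-info) +
               " had an illegal value.";
    }
    if (info == scratchpad_size && detail != 0) {
        return "Insufficient scratchpad size. Required size is at least " +
               std::to_string(detail);
    }
    if (info > 0) {
        return "The algorithm computing SVD failed to converge; " +
               std::to_string(info) +
               " off-diagonal elements of an intermediate bidiagonal form "
               "did not converge to zero.";
    }
    return "Unexpected MKL exception caught during gesvd() call.";
}
```

For a 5 x 3 input with `jobvt = 'S'`, `leading_dims(5, 3, 'S')` gives `lda = 5`, `ldu = 5`, `ldvt = 3`; a 3 x 5 input with `jobvt = 'A'` gives `ldvt = 5`. Checking the scratchpad case before `info > 0` matters because a scratchpad-size error also reports a positive `info`.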