-
Notifications
You must be signed in to change notification settings - Fork 22
Update dpnp.linalg.svd() function #1604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
78 commits
Select commit
Hold shift + click to select a range
d464ba3
Draft commit of dpnp.linalg.svd impl
vlad-perevezentsev 48d9b61
Pass empty arrays if compute_uv=False
vlad-perevezentsev f6b4c1f
Remove unnecessary comments
vlad-perevezentsev 9a4d62e
Add logic for the input array n < m
vlad-perevezentsev a942dcf
Small changes
vlad-perevezentsev 54febf6
Add a new cupy test_decomposition
vlad-perevezentsev b292afb
Merge master into impl_svd
vlad-perevezentsev 4da97be
Rename gesvd input parameters
vlad-perevezentsev 303f23a
Correspondence of passed parameters to gesvd signature
vlad-perevezentsev 80d5e40
Correct initialization of result variables in dpnp_svd
vlad-perevezentsev e386f3d
Update test_decomposition
vlad-perevezentsev da6d0c9
Add implementation of _dpnp_svd_batch
vlad-perevezentsev d59ac0c
Add test_decomposition to the scope of public CI
vlad-perevezentsev 02270f8
Use mkl_lapack
vlad-perevezentsev 8d01e4f
Improve error handling for mkl_lapack::gesvd function
vlad-perevezentsev 6490d8d
Declate detail variable
vlad-perevezentsev ca427c6
Use a_usm_type and a_sycl_queue variables
vlad-perevezentsev 327e8df
Check a.ndim < 2
vlad-perevezentsev 0a99b84
Add additional checks for gesvd function
vlad-perevezentsev d03405e
Merge master into impl_svd
vlad-perevezentsev 25bfbb5
Remove old dpnp_svd backend
vlad-perevezentsev f123ece
Refresh test_svd in test_linalg
vlad-perevezentsev a96fc7e
Merge master into impl_svd
vlad-perevezentsev 7cf7d31
Add detailed comments for gesvd arguments
vlad-perevezentsev 3146428
gesvd returns pair of events and uses dpctl.utils.keep_args_alive
vlad-perevezentsev 5997a3a
Keep a lexicographical order
vlad-perevezentsev d1ce945
Update docstrings for svd
vlad-perevezentsev d4e5389
Merge master into impl_svd
vlad-perevezentsev 468c9c8
Merge master into impl_svd
vlad-perevezentsev 1a6f4c0
Add test_svd to test_usm_type
vlad-perevezentsev c08959d
Merge master into impl_svd
vlad-perevezentsev e43b9a3
Remove all TODOs
vlad-perevezentsev 96f243c
Add a new impl to get s_type
vlad-perevezentsev 87825f5
Address remarks
vlad-perevezentsev 9765e35
Add a description for _stacked_identity
vlad-perevezentsev 61257e2
Move internal funcs up
vlad-perevezentsev 61d1ed7
Simplify dpnp_svd_batch
vlad-perevezentsev 33c4f5e
Update tests for dpnp.linalg.svd
vlad-perevezentsev 0be3132
Add hermitian argument support
vlad-perevezentsev 6b63eea
Add test_svd_hermitian
vlad-perevezentsev 12a5cb5
Update svd docstrings
vlad-perevezentsev 3845613
Tune tolerance
vlad-perevezentsev d12d49d
Merge master into impl_svd
vlad-perevezentsev c0e0462
Update test_svd_errors
vlad-perevezentsev ddb4d43
Update _common_type and _common_inexact_type
vlad-perevezentsev 8f4b557
Address the remarks
vlad-perevezentsev 6ba9483
Remove passing n and m parameteres to _gesvd
vlad-perevezentsev 6008ce4
Simplify results return logic for dpnp_svd_batch
vlad-perevezentsev 16369a0
Update condition and random files in cupy/testing to use fix_random a…
vlad-perevezentsev 937d2ac
Merge master into impl_svd
vlad-perevezentsev 9235112
Rename cupy/testing/condition.py to .../_condition.py
vlad-perevezentsev 8815da2
Use self._tol in TestSvd
vlad-perevezentsev f28e9ae
Add TODO for check_decomposition
vlad-perevezentsev 2118d3e
Update gesvd error handler
vlad-perevezentsev c89817a
Merge master into impl_svd
vlad-perevezentsev 1fcfb5e
Merge master into impl_svd
vlad-perevezentsev 347fbe9
dpnp_svd works with F contiguous arrays
vlad-perevezentsev 03d04d7
Merge master into impl_svd
vlad-perevezentsev dd74a36
Add additional checks for output arrays
vlad-perevezentsev 5a4721f
Tune atol for TestSvd
vlad-perevezentsev 31e8bbb
Impl parallel calculation in dpnp_svd_batch
vlad-perevezentsev 876842a
Skip using @_condition.repeat in cupy tests
vlad-perevezentsev 9d5afde
Resolve conflicts
vlad-perevezentsev 1f66d0a
Merge master into impl_svd
vlad-perevezentsev 63f257f
Add additional checks for output arrays
vlad-perevezentsev 8c2a8ff
Update docstrings for svd
vlad-perevezentsev 7a27d4c
Use dpctl.SyclEvent.wait_for in dpnp_svd_batch
vlad-perevezentsev ac4283a
Address remarks
vlad-perevezentsev f64e6b2
Merge master into impl_svd
vlad-perevezentsev 647befa
Add TODO : matching the order of returned arrays
vlad-perevezentsev eeb272b
Merge master into impl_svd
vlad-perevezentsev 405530c
Merge master into impl_svd
vlad-perevezentsev 036e1df
Skip cupy tests on windows
vlad-perevezentsev 0fef6ce
Rename condition to _condition
vlad-perevezentsev 7d7a01c
Tune tol in TestSvd
vlad-perevezentsev 3c1c3be
Merge master into impl_svd
vlad-perevezentsev 6344767
Set setUpClass to skip cupy tests on cpu
vlad-perevezentsev cee13dc
Merge master into impl_svd
vlad-perevezentsev File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,359 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
// dpctl tensor headers | ||
#include "utils/memory_overlap.hpp" | ||
#include "utils/type_utils.hpp" | ||
|
||
#include "gesvd.hpp" | ||
#include "types_matrix.hpp" | ||
|
||
#include "dpnp_utils.hpp" | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
namespace mkl_lapack = oneapi::mkl::lapack; | ||
namespace py = pybind11; | ||
namespace type_utils = dpctl::tensor::type_utils; | ||
|
||
typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue, | ||
const oneapi::mkl::jobsvd, | ||
const oneapi::mkl::jobsvd, | ||
const std::int64_t, | ||
const std::int64_t, | ||
char *, | ||
const std::int64_t, | ||
char *, | ||
char *, | ||
const std::int64_t, | ||
char *, | ||
const std::int64_t, | ||
std::vector<sycl::event> &, | ||
const std::vector<sycl::event> &); | ||
|
||
static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types] | ||
[dpctl_td_ns::num_types]; | ||
|
||
// Converts a given character code (ord) to the corresponding | ||
// oneapi::mkl::jobsvd enumeration value | ||
static oneapi::mkl::jobsvd process_job(std::int8_t job_val) | ||
{ | ||
switch (job_val) { | ||
case 'A': | ||
return oneapi::mkl::jobsvd::vectors; | ||
case 'S': | ||
return oneapi::mkl::jobsvd::somevec; | ||
case 'O': | ||
return oneapi::mkl::jobsvd::vectorsina; | ||
case 'N': | ||
return oneapi::mkl::jobsvd::novec; | ||
default: | ||
throw std::invalid_argument("Unknown value for job"); | ||
} | ||
} | ||
|
||
template <typename T, typename RealT> | ||
static sycl::event gesvd_impl(sycl::queue exec_q, | ||
const oneapi::mkl::jobsvd jobu, | ||
const oneapi::mkl::jobsvd jobvt, | ||
const std::int64_t m, | ||
const std::int64_t n, | ||
char *in_a, | ||
const std::int64_t lda, | ||
char *out_s, | ||
char *out_u, | ||
const std::int64_t ldu, | ||
char *out_vt, | ||
const std::int64_t ldvt, | ||
std::vector<sycl::event> &host_task_events, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
type_utils::validate_type_for_device<T>(exec_q); | ||
type_utils::validate_type_for_device<RealT>(exec_q); | ||
|
||
T *a = reinterpret_cast<T *>(in_a); | ||
RealT *s = reinterpret_cast<RealT *>(out_s); | ||
T *u = reinterpret_cast<T *>(out_u); | ||
T *vt = reinterpret_cast<T *>(out_vt); | ||
|
||
const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>( | ||
exec_q, jobu, jobvt, m, n, lda, ldu, ldvt); | ||
T *scratchpad = nullptr; | ||
|
||
std::stringstream error_msg; | ||
std::int64_t info = 0; | ||
bool is_exception_caught = false; | ||
|
||
sycl::event gesvd_event; | ||
try { | ||
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q); | ||
|
||
gesvd_event = mkl_lapack::gesvd( | ||
exec_q, | ||
jobu, // Character specifying how to compute the matrix U: | ||
// 'A' computes all columns of U, | ||
// 'S' computes the first min(m,n) columns of U, | ||
// 'O' overwrites A with the columns of U, | ||
// 'N' does not compute U. | ||
jobvt, // Character specifying how to compute the matrix VT: | ||
// 'A' computes all rows of VT, | ||
// 'S' computes the first min(m,n) rows of VT, | ||
// 'O' overwrites A with the rows of VT, | ||
// 'N' does not compute VT. | ||
m, // The number of rows in the input matrix A (0 <= m). | ||
n, // The number of columns in the input matrix A (0 <= n). | ||
a, // Pointer to the input matrix A of size (m x n). | ||
lda, // The leading dimension of A, must be at least max(1, m). | ||
s, // Pointer to the array containing the singular values. | ||
u, // Pointer to the matrix U in the singular value decomposition. | ||
ldu, // The leading dimension of U, must be at least max(1, m). | ||
vt, // Pointer to the matrix VT in the singular value decomposition. | ||
ldvt, // The leading dimension of VT, must be at least max(1, n). | ||
scratchpad, // Pointer to scratchpad memory to be used by MKL | ||
// routine for storing intermediate results. | ||
scratchpad_size, depends); | ||
} catch (mkl_lapack::exception const &e) { | ||
is_exception_caught = true; | ||
info = e.info(); | ||
if (info < 0) { | ||
error_msg << "Parameter number " << -info | ||
<< " had an illegal value."; | ||
} | ||
else if (info == scratchpad_size && e.detail() != 0) { | ||
error_msg | ||
<< "Insufficient scratchpad size. Required size is at least " | ||
<< e.detail(); | ||
} | ||
else if (info > 0) { | ||
error_msg << "The algorithm computing SVD failed to converge; " | ||
<< info << " off-diagonal elements of an intermediate " | ||
<< "bidiagonal form did not converge to zero.\n"; | ||
} | ||
else { | ||
error_msg << "Unexpected MKL exception caught during gesvd() " | ||
"call:\nreason: " | ||
<< e.what() << "\ninfo: " << e.info(); | ||
} | ||
} catch (sycl::exception const &e) { | ||
is_exception_caught = true; | ||
error_msg << "Unexpected SYCL exception caught during gesvd() call:\n" | ||
<< e.what(); | ||
} | ||
|
||
if (is_exception_caught) // an unexpected error occurs | ||
{ | ||
if (scratchpad != nullptr) { | ||
sycl::free(scratchpad, exec_q); | ||
} | ||
throw std::runtime_error(error_msg.str()); | ||
} | ||
|
||
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { | ||
cgh.depends_on(gesvd_event); | ||
auto ctx = exec_q.get_context(); | ||
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); | ||
}); | ||
host_task_events.push_back(clean_up_event); | ||
return gesvd_event; | ||
} | ||
|
||
std::pair<sycl::event, sycl::event> | ||
gesvd(sycl::queue exec_q, | ||
const std::int8_t jobu_val, | ||
const std::int8_t jobvt_val, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray out_s, | ||
dpctl::tensor::usm_ndarray out_u, | ||
dpctl::tensor::usm_ndarray out_vt, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
const int a_array_nd = a_array.get_ndim(); | ||
const int out_u_array_nd = out_u.get_ndim(); | ||
const int out_s_array_nd = out_s.get_ndim(); | ||
const int out_vt_array_nd = out_vt.get_ndim(); | ||
|
||
if (a_array_nd != 2) { | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
throw py::value_error( | ||
"The input array has ndim=" + std::to_string(a_array_nd) + | ||
", but a 2-dimensional array is expected."); | ||
} | ||
|
||
if (out_s_array_nd != 1) { | ||
throw py::value_error("The output array of singular values has ndim=" + | ||
std::to_string(out_s_array_nd) + | ||
", but a 1-dimensional array is expected."); | ||
} | ||
|
||
if (jobu_val == 'N' && jobvt_val == 'N') { | ||
if (out_u_array_nd != 0) { | ||
throw py::value_error( | ||
"The output array of the left singular vectors has ndim=" + | ||
std::to_string(out_u_array_nd) + | ||
", but it is not used and should have ndim=0."); | ||
} | ||
if (out_vt_array_nd != 0) { | ||
throw py::value_error( | ||
"The output array of the right singular vectors has ndim=" + | ||
std::to_string(out_vt_array_nd) + | ||
", but it is not used and should have ndim=0."); | ||
} | ||
} | ||
else { | ||
if (out_u_array_nd != 2) { | ||
throw py::value_error( | ||
"The output array of the left singular vectors has ndim=" + | ||
std::to_string(out_u_array_nd) + | ||
", but a 2-dimensional array is expected."); | ||
} | ||
if (out_vt_array_nd != 2) { | ||
throw py::value_error( | ||
"The output array of the right singular vectors has ndim=" + | ||
std::to_string(out_vt_array_nd) + | ||
", but a 2-dimensional array is expected."); | ||
} | ||
} | ||
|
||
// check compatibility of execution queue and allocation queue | ||
if (!dpctl::utils::queues_are_compatible( | ||
exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(), | ||
out_vt.get_queue()})) | ||
{ | ||
throw std::runtime_error( | ||
"USM allocations are not compatible with the execution queue."); | ||
} | ||
|
||
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); | ||
if (overlap(a_array, out_s) || overlap(a_array, out_u) || | ||
overlap(a_array, out_vt) || overlap(out_s, out_u) || | ||
overlap(out_s, out_vt) || overlap(out_u, out_vt)) | ||
{ | ||
throw py::value_error("Arrays have overlapping segments of memory"); | ||
} | ||
|
||
bool is_a_array_f_contig = a_array.is_f_contiguous(); | ||
if (!is_a_array_f_contig) { | ||
throw py::value_error("The input array must be F-contiguous"); | ||
} | ||
|
||
bool is_out_u_array_f_contig = out_u.is_f_contiguous(); | ||
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous(); | ||
|
||
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) { | ||
throw py::value_error("The output arrays of the left and right " | ||
"singular vectors must be F-contiguous"); | ||
} | ||
|
||
bool is_out_s_array_c_contig = out_s.is_c_contiguous(); | ||
bool is_out_s_array_f_contig = out_s.is_f_contiguous(); | ||
|
||
if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) { | ||
throw py::value_error("The output array of singular values " | ||
antonwolfy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"must be contiguous"); | ||
} | ||
|
||
auto array_types = dpctl_td_ns::usm_ndarray_types(); | ||
int a_array_type_id = | ||
array_types.typenum_to_lookup_id(a_array.get_typenum()); | ||
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum()); | ||
int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum()); | ||
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum()); | ||
|
||
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) { | ||
throw py::type_error( | ||
"Input array, output left singular vectors array, " | ||
"and outpuy right singular vectors array must have " | ||
"the same data type"); | ||
} | ||
|
||
gesvd_impl_fn_ptr_t gesvd_fn = | ||
gesvd_dispatch_table[a_array_type_id][out_s_type_id]; | ||
if (gesvd_fn == nullptr) { | ||
throw py::value_error( | ||
"No gesvd implementation is defined for the given pair " | ||
"of array type and output singular values type."); | ||
} | ||
|
||
char *a_array_data = a_array.get_data(); | ||
char *out_s_data = out_s.get_data(); | ||
char *out_u_data = out_u.get_data(); | ||
char *out_vt_data = out_vt.get_data(); | ||
|
||
const py::ssize_t *a_array_shape = a_array.get_shape_raw(); | ||
const std::int64_t m = a_array_shape[0]; | ||
const std::int64_t n = a_array_shape[1]; | ||
|
||
const std::int64_t lda = std::max<size_t>(1UL, m); | ||
const std::int64_t ldu = std::max<size_t>(1UL, m); | ||
const std::int64_t ldvt = | ||
std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n); | ||
|
||
const oneapi::mkl::jobsvd jobu = process_job(jobu_val); | ||
const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val); | ||
|
||
std::vector<sycl::event> host_task_events; | ||
sycl::event gesvd_ev = | ||
gesvd_fn(exec_q, jobu, jobvt, m, n, a_array_data, lda, out_s_data, | ||
out_u_data, ldu, out_vt_data, ldvt, host_task_events, depends); | ||
|
||
sycl::event args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {a_array, out_s, out_u, out_vt}, host_task_events); | ||
|
||
return std::make_pair(args_ev, gesvd_ev); | ||
} | ||
|
||
template <typename fnT, typename T, typename RealT> | ||
struct GesvdContigFactory | ||
{ | ||
fnT get() | ||
{ | ||
if constexpr (types::GesvdTypePairSupportFactory<T, RealT>::is_defined) | ||
{ | ||
return gesvd_impl<T, RealT>; | ||
} | ||
else { | ||
return nullptr; | ||
} | ||
} | ||
}; | ||
|
||
void init_gesvd_dispatch_table(void) | ||
{ | ||
dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t, GesvdContigFactory, | ||
dpctl_td_ns::num_types> | ||
contig; | ||
contig.populate_dispatch_table(gesvd_dispatch_table); | ||
} | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.