Skip to content

Implement gesv_batch via gesv call #1877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 50 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
4c7c8c2
Init work
vlad-perevezentsev Jun 6, 2024
8e2cb23
First working version with transpose and C contig
vlad-perevezentsev Jun 7, 2024
67fa435
Second working version with moveaxis, transpose and F contig
vlad-perevezentsev Jun 7, 2024
4f5abec
Add more shape checks
vlad-perevezentsev Jun 11, 2024
0cb2808
Pass sycl::queue by reference for gesv/gesv_batch
vlad-perevezentsev Jun 11, 2024
bfa37d4
qwe
vlad-perevezentsev Jun 11, 2024
4a44292
Update _batched_solve implementation
vlad-perevezentsev Jun 12, 2024
df4774e
Remove old impl in _batched_solve
vlad-perevezentsev Jun 12, 2024
8dbe3c4
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jun 12, 2024
8fb2af3
Use py::gil_scoped_release before gesv call
vlad-perevezentsev Jun 12, 2024
ddcf9fe
Remove junk files
vlad-perevezentsev Jun 12, 2024
262794f
Move gesv_batch to gesv_batch.cpp
vlad-perevezentsev Jun 13, 2024
3a7b8ca
Improve gesv_batch with independent linear streams
vlad-perevezentsev Jun 13, 2024
2016a8c
Extend checks for gesv/gesv_batch
vlad-perevezentsev Jun 13, 2024
2c42290
Update comment
vlad-perevezentsev Jun 13, 2024
e030da8
junk files
vlad-perevezentsev Jun 14, 2024
3f99ae5
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jun 17, 2024
a0a683b
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 11, 2024
5a48f33
Add common_gesv_checks
vlad-perevezentsev Jul 12, 2024
924fee7
Release GIL in gesv_batch_impl
vlad-perevezentsev Jul 12, 2024
2b15e6c
Remove junk file
vlad-perevezentsev Jul 12, 2024
5a1cab6
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 12, 2024
b5c3062
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 15, 2024
afca803
Remove junk files
vlad-perevezentsev Jul 16, 2024
ed99888
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 16, 2024
0c97aff
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 19, 2024
1b275ea
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 26, 2024
e5b53a1
Remove host_task_events from gesv
vlad-perevezentsev Jul 26, 2024
d5adbd6
Use check_zeros_shape in gesv and gesv_batch
vlad-perevezentsev Jul 26, 2024
5b2780c
Add additional checks for gesv_impl
vlad-perevezentsev Jul 26, 2024
d4547d4
Move alloc_scratchpad to common_helpers.hpp
vlad-perevezentsev Jul 26, 2024
6759164
Use helper::alloc_scratchpad in gesv_batch_impl
vlad-perevezentsev Jul 26, 2024
f37ec43
Remove current_scratch_gesv check
vlad-perevezentsev Jul 26, 2024
adc17ba
Remove lda, ldb pass to gesv_batch_impl, gesv_impl
vlad-perevezentsev Jul 26, 2024
77ba0e2
Use const and constexpr in gesv/gesv_batch
vlad-perevezentsev Jul 26, 2024
9bf94b5
Applied review comments
vlad-perevezentsev Jul 29, 2024
b81893c
Use dpnp.reshape in _batched_solve
vlad-perevezentsev Jul 29, 2024
f8d68ef
Implement alloc_ipiv in common_helpers.hpp
vlad-perevezentsev Jul 29, 2024
fc6c7fa
Add gesv_common_utils.hpp
vlad-perevezentsev Jul 29, 2024
75079d2
Implement handle_lapack_exc function
vlad-perevezentsev Jul 29, 2024
6e82632
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 29, 2024
7e0f384
Use try/catch for scratchpad/ipiv allocation
vlad-perevezentsev Jul 29, 2024
f5ee368
Update alloc_scratchpad/alloc_ipiv
vlad-perevezentsev Jul 29, 2024
eb8c3a0
gesv_scratchpad_size can be 0
vlad-perevezentsev Jul 30, 2024
3c8cda6
Implement help functions alloc_ipiv/alloc_scratchpad
vlad-perevezentsev Jul 30, 2024
3f4d672
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Jul 30, 2024
e56e07e
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 2, 2024
629b97a
Reuse alloc_scratchpad/ipiv in batch versions
vlad-perevezentsev Aug 2, 2024
a9cc253
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 6, 2024
3786ca2
Merge master into impl_gesv_batch_via_gesv
vlad-perevezentsev Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 164 additions & 8 deletions dpnp/backend/extensions/lapack/gesv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
const std::int64_t,
char *,
Expand All @@ -61,7 +61,7 @@ typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue,
static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event gesv_impl(sycl::queue exec_q,
static sycl::event gesv_impl(sycl::queue &exec_q,
const std::int64_t n,
const std::int64_t nrhs,
char *in_a,
Expand Down Expand Up @@ -176,7 +176,7 @@ static sycl::event gesv_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gesv(sycl::queue exec_q,
gesv(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray coeff_matrix,
dpctl::tensor::usm_ndarray dependent_vals,
const std::vector<sycl::event> &depends)
Expand Down Expand Up @@ -212,26 +212,26 @@ std::pair<sycl::event, sycl::event>
{coeff_matrix, dependent_vals}))
{
throw py::value_error(
"Execution queue is not compatible with allocation queues");
"Execution queue is not compatible with allocation queues.");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
if (overlap(coeff_matrix, dependent_vals)) {
throw py::value_error(
"The arrays of coefficients and dependent variables "
"are overlapping segments of memory");
"are overlapping segments of memory.");
}

bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
if (!is_coeff_matrix_f_contig) {
throw py::value_error("The coefficient matrix "
"must be F-contiguous");
"must be F-contiguous.");
}

bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
if (!is_dependent_vals_f_contig) {
throw py::value_error("The array of dependent variables "
"must be F-contiguous");
"must be F-contiguous.");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
Expand All @@ -242,7 +242,7 @@ std::pair<sycl::event, sycl::event>

if (coeff_matrix_type_id != dependent_vals_type_id) {
throw py::value_error("The types of the coefficient matrix and "
"dependent variables are mismatched");
"dependent variables are mismatched.");
}

gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id];
Expand Down Expand Up @@ -273,6 +273,162 @@ std::pair<sycl::event, sycl::event>
return std::make_pair(args_ev, gesv_ev);
}

/**
 * @brief Solves a batch of linear systems by invoking the type-dispatched
 *        `gesv` implementation once per matrix of the batch.
 *
 * The batch axis is expected to be the LAST axis of both arrays and both
 * arrays must be F-contiguous, so every matrix of the batch occupies one
 * contiguous chunk of memory that is addressed below by a plain byte offset.
 *
 * @param exec_q         Queue the work is submitted to; must be compatible
 *                       with the allocation queues of both arrays.
 * @param coeff_matrix   3-D array of shape (n, n, batch_size).
 * @param dependent_vals 2-D array of shape (n, batch_size) or 3-D array of
 *                       shape (n, nrhs, batch_size).
 * @param depends        Events forwarded to every per-matrix `gesv` call.
 *
 * @return A pair (args_ev, combine_ev): `args_ev` is the event returned by
 *         `keep_args_alive` for both arrays, `combine_ev` is an event that
 *         depends on all per-matrix `gesv` events.
 *
 * @throws py::value_error if the dimensionality, shapes, queues, memory
 *         overlap, contiguity or data types of the arguments are invalid, or
 *         if no `gesv` implementation is registered for the data type.
 */
std::pair<sycl::event, sycl::event>
gesv_batch(sycl::queue &exec_q,
           dpctl::tensor::usm_ndarray coeff_matrix,
           dpctl::tensor::usm_ndarray dependent_vals,
           const std::vector<sycl::event> &depends)
{
    const int coeff_matrix_nd = coeff_matrix.get_ndim();
    const int dependent_vals_nd = dependent_vals.get_ndim();

    if (coeff_matrix_nd != 3) {
        throw py::value_error("The coefficient matrix has ndim=" +
                              std::to_string(coeff_matrix_nd) +
                              ", but a 3-dimensional array is expected.");
    }

    if (dependent_vals_nd < 2 || dependent_vals_nd > 3) {
        throw py::value_error(
            "The dependent values array has ndim=" +
            std::to_string(dependent_vals_nd) +
            ", but a 2-dimensional or 3-dimensional array is expected.");
    }

    const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw();
    const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw();

    // The coeff_matrix and dependent_vals arrays must be F-contiguous arrays
    // with the shapes (n,n,batch_size) and (n,nrhs,batch_size) or
    // (n,batch_size) respectively
    if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
        throw py::value_error("The coefficient matrix must be square,"
                              " but got a shape of (" +
                              std::to_string(coeff_matrix_shape[0]) + ", " +
                              std::to_string(coeff_matrix_shape[1]) + ").");
    }

    if (coeff_matrix_shape[0] != dependent_vals_shape[0]) {
        throw py::value_error("The first dimension (n) of coeff_matrix and"
                              " dependent_vals must be the same, but got " +
                              std::to_string(coeff_matrix_shape[0]) + " and " +
                              std::to_string(dependent_vals_shape[0]) + ".");
    }

    // The batch axis is the last one for both the 2-D and the 3-D layout of
    // dependent_vals (ndim was validated above to be 2 or 3), so a single
    // check covers both cases.
    const py::ssize_t dependent_vals_batch_size =
        dependent_vals_shape[dependent_vals_nd - 1];
    if (coeff_matrix_shape[2] != dependent_vals_batch_size) {
        throw py::value_error(
            "The batch_size of coeff_matrix and dependent_vals must be"
            " the same, but got " +
            std::to_string(coeff_matrix_shape[2]) + " and " +
            std::to_string(dependent_vals_batch_size) + ".");
    }

    // check compatibility of execution queue and allocation queue
    if (!dpctl::utils::queues_are_compatible(exec_q,
                                             {coeff_matrix, dependent_vals}))
    {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues.");
    }

    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(coeff_matrix, dependent_vals)) {
        throw py::value_error(
            "The arrays of coefficients and dependent variables "
            "are overlapping segments of memory.");
    }

    bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
    if (!is_coeff_matrix_f_contig) {
        throw py::value_error("The coefficient matrix "
                              "must be F-contiguous.");
    }

    bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
    if (!is_dependent_vals_f_contig) {
        throw py::value_error("The array of dependent variables "
                              "must be F-contiguous.");
    }

    auto array_types = dpctl_td_ns::usm_ndarray_types();
    int coeff_matrix_type_id =
        array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
    int dependent_vals_type_id =
        array_types.typenum_to_lookup_id(dependent_vals.get_typenum());

    if (coeff_matrix_type_id != dependent_vals_type_id) {
        throw py::value_error("The types of the coefficient matrix and "
                              "dependent variables are mismatched.");
    }

    gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id];
    if (gesv_fn == nullptr) {
        throw py::value_error(
            "No gesv implementation defined for the provided type "
            "of the coefficient matrix.");
    }

    char *coeff_matrix_data = coeff_matrix.get_data();
    char *dependent_vals_data = dependent_vals.get_data();

    const std::int64_t batch_size = coeff_matrix_shape[2];
    const std::int64_t n = coeff_matrix_shape[1];
    // A 2-D dependent_vals of shape (n, batch_size) holds a single
    // right-hand side per system
    const std::int64_t nrhs =
        (dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1;

    // Leading dimensions of the column-major n x n and n x nrhs matrices;
    // clamped to 1 so they stay positive when n == 0
    const std::int64_t lda = std::max<size_t>(1UL, n);
    const std::int64_t ldb = std::max<size_t>(1UL, n);

    int coeff_matrix_elemsize = coeff_matrix.get_elemsize();
    int dependent_vals_elemsize = dependent_vals.get_elemsize();

    // host_task_events is handed to every gesv_fn call
    // NOTE(review): presumably gesv_fn appends its host-task events to it -
    // confirm against gesv_impl
    std::vector<sycl::event> host_task_events;
    std::vector<sycl::event> gesv_task_events;

    host_task_events.reserve(batch_size);
    gesv_task_events.reserve(batch_size);

    {
        // Release GIL to allow other Python threads to run during the loop
        // as the operations in the loop do not require GIL
        py::gil_scoped_release release;

        for (std::int64_t i = 0; i < batch_size; ++i) {
            // Byte offsets of the i-th matrix / right-hand side block
            char *coeff_matrix_batch =
                coeff_matrix_data + i * n * n * coeff_matrix_elemsize;
            char *dependent_vals_batch =
                dependent_vals_data + i * n * nrhs * dependent_vals_elemsize;

            sycl::event gesv_ev =
                gesv_fn(exec_q, n, nrhs, coeff_matrix_batch, lda,
                        dependent_vals_batch, ldb, host_task_events, depends);

            gesv_task_events.push_back(gesv_ev);
        }
    }

    // Single event that depends on all per-matrix gesv events
    sycl::event combine_ev = exec_q.submit(
        [&](sycl::handler &cgh) { cgh.depends_on(gesv_task_events); });

    sycl::event args_ev = dpctl::utils::keep_args_alive(
        exec_q, {coeff_matrix, dependent_vals}, host_task_events);

    return std::make_pair(args_ev, combine_ev);
}

template <typename fnT, typename T>
struct GesvContigFactory
{
Expand Down
8 changes: 7 additions & 1 deletion dpnp/backend/extensions/lapack/gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,17 @@ namespace ext
namespace lapack
{
extern std::pair<sycl::event, sycl::event>
gesv(sycl::queue exec_q,
gesv(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray coeff_matrix,
dpctl::tensor::usm_ndarray dependent_vals,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
gesv_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray coeff_matrix,
dpctl::tensor::usm_ndarray dependent_vals,
const std::vector<sycl::event> &depends);

extern void init_gesv_dispatch_vector(void);
} // namespace lapack
} // namespace ext
Expand Down
7 changes: 7 additions & 0 deletions dpnp/backend/extensions/lapack/lapack_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ PYBIND11_MODULE(_lapack_impl, m)
py::arg("sycl_queue"), py::arg("coeff_matrix"),
py::arg("dependent_vals"), py::arg("depends") = py::list());

m.def("_gesv_batch", &lapack_ext::gesv_batch,
"Call `gesv` from OneMKL LAPACK library to return "
"the batch solution of a system of linear equations with "
"a square coefficient matrix A and multiple dependent variables",
py::arg("sycl_queue"), py::arg("coeff_matrix"),
py::arg("dependent_vals"), py::arg("depends") = py::list());

m.def("_gesvd", &lapack_ext::gesvd,
"Call `gesvd` from OneMKL LAPACK library to return "
"the singular value decomposition of a general rectangular matrix",
Expand Down
Loading
Loading