update BLAS extension routines #1884

Merged: 2 commits, Jun 17, 2024

57 changes: 30 additions & 27 deletions dpnp/backend/extensions/blas/blas_py.cpp
@@ -37,17 +37,17 @@
 #include "gemm.hpp"
 #include "gemv.hpp"
 
-namespace blas_ext = dpnp::backend::ext::blas;
+namespace blas_ns = dpnp::extensions::blas;
 namespace py = pybind11;
-namespace dot_ext = blas_ext::dot;
-using dot_ext::dot_impl_fn_ptr_t;
+namespace dot_ns = blas_ns::dot;
+using dot_ns::dot_impl_fn_ptr_t;
 
 // populate dispatch vectors and tables
 void init_dispatch_vectors_tables(void)
 {
-    blas_ext::init_gemm_batch_dispatch_table();
-    blas_ext::init_gemm_dispatch_table();
-    blas_ext::init_gemv_dispatch_vector();
+    blas_ns::init_gemm_batch_dispatch_table();
+    blas_ns::init_gemm_dispatch_table();
+    blas_ns::init_gemv_dispatch_vector();
 }
 
 static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
@@ -62,14 +62,15 @@ PYBIND11_MODULE(_blas_impl, m)
     using event_vecT = std::vector<sycl::event>;
 
     {
-        dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
-                                          blas_ext::DotContigFactory>(
+        dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
+                                         blas_ns::DotContigFactory>(
             dot_dispatch_vector);
 
-        auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                             arrayT dst, const event_vecT &depends = {}) {
-            return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
-                                     dot_dispatch_vector);
+        auto dot_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
+                             const arrayT &src2, const arrayT &dst,
+                             const event_vecT &depends = {}) {
+            return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
+                                    dot_dispatch_vector);
         };
 
         m.def("_dot", dot_pyapi,
@@ -80,14 +81,15 @@ PYBIND11_MODULE(_blas_impl, m)
     }
 
     {
-        dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
-                                          blas_ext::DotcContigFactory>(
+        dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
+                                         blas_ns::DotcContigFactory>(
             dotc_dispatch_vector);
 
-        auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                              arrayT dst, const event_vecT &depends = {}) {
-            return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
-                                     dotc_dispatch_vector);
+        auto dotc_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
+                              const arrayT &src2, const arrayT &dst,
+                              const event_vecT &depends = {}) {
+            return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
+                                    dotc_dispatch_vector);
         };
 
         m.def("_dotc", dotc_pyapi,
@@ -99,14 +101,15 @@ PYBIND11_MODULE(_blas_impl, m)
     }
 
     {
-        dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
-                                          blas_ext::DotuContigFactory>(
+        dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
+                                         blas_ns::DotuContigFactory>(
             dotu_dispatch_vector);
 
-        auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                              arrayT dst, const event_vecT &depends = {}) {
-            return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
-                                     dotu_dispatch_vector);
+        auto dotu_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
+                              const arrayT &src2, const arrayT &dst,
+                              const event_vecT &depends = {}) {
+            return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
+                                    dotu_dispatch_vector);
         };
 
         m.def("_dotu", dotu_pyapi,
@@ -117,23 +120,23 @@ PYBIND11_MODULE(_blas_impl, m)
     }
 
     {
-        m.def("_gemm", &blas_ext::gemm,
+        m.def("_gemm", &blas_ns::gemm,
              "Call `gemm` from OneMKL BLAS library to compute "
              "the matrix-matrix product with 2-D matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("resultC"), py::arg("depends") = py::list());
     }
 
     {
-        m.def("_gemm_batch", &blas_ext::gemm_batch,
+        m.def("_gemm_batch", &blas_ns::gemm_batch,
              "Call `gemm_batch` from OneMKL BLAS library to compute "
              "the matrix-matrix product for a batch of 2-D matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("resultC"), py::arg("depends") = py::list());
     }
 
     {
-        m.def("_gemv", &blas_ext::gemv,
+        m.def("_gemv", &blas_ns::gemv,
              "Call `gemv` from OneMKL BLAS library to compute "
              "the matrix-vector product with a general matrix.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),

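The bindings above all follow one pattern: populate a dispatch vector once, wrap the typed implementation in a lambda that takes the queue and arrays by const reference, and expose the lambda with `m.def`, defaulting `depends` to an empty list. The snippet below is a standalone sketch of that binding style using pybind11 only (no SYCL or dpctl); the module name `_example_impl`, the `_concat` function, and its parameters are invented for illustration.

```cpp
// Minimal sketch: a pybind11 binding with const-reference parameters and a
// defaulted trailing argument, mirroring the _dot/_dotc/_dotu lambdas above.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <string>
#include <vector>

namespace py = pybind11;

PYBIND11_MODULE(_example_impl, m)
{
    using vecT = std::vector<int>;

    // Const references avoid copying the inputs; the Python-side default
    // comes from py::arg(...) = py::list(), as in the PR.
    auto concat_pyapi = [](const std::string &a, const std::string &b,
                           const vecT &extra = {}) {
        std::string out = a + b;
        for (int v : extra) {
            out += std::to_string(v);
        }
        return out;
    };

    m.def("_concat", concat_pyapi,
          "Concatenate two strings and append optional integers.",
          py::arg("a"), py::arg("b"), py::arg("extra") = py::list());
}
```

From Python, both `_concat("a", "b")` and `_concat("a", "b", [1, 2])` work, since the defaulted argument is materialized as an empty list when omitted.
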
21 changes: 6 additions & 15 deletions dpnp/backend/extensions/blas/dot.hpp
@@ -27,31 +27,25 @@
 
 #include "dot_common.hpp"
 
-namespace dpnp
-{
-namespace backend
-{
-namespace ext
-{
-namespace blas
+namespace dpnp::extensions::blas
 {
 namespace mkl_blas = oneapi::mkl::blas;
 namespace type_utils = dpctl::tensor::type_utils;
 
 template <typename T>
 static sycl::event dot_impl(sycl::queue &exec_q,
                             const std::int64_t n,
-                            char *vectorX,
+                            const char *vectorX,
                             const std::int64_t incx,
-                            char *vectorY,
+                            const char *vectorY,
                             const std::int64_t incy,
                             char *result,
                             const std::vector<sycl::event> &depends)
 {
     type_utils::validate_type_for_device<T>(exec_q);
 
-    T *x = reinterpret_cast<T *>(vectorX);
-    T *y = reinterpret_cast<T *>(vectorY);
+    const T *x = reinterpret_cast<const T *>(vectorX);
+    const T *y = reinterpret_cast<const T *>(vectorY);
     T *res = reinterpret_cast<T *>(result);
 
     std::stringstream error_msg;
@@ -99,7 +93,4 @@ struct DotContigFactory
         }
     }
 };
-} // namespace blas
-} // namespace ext
-} // namespace backend
-} // namespace dpnp
+} // namespace dpnp::extensions::blas

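Besides const-qualifying the input buffers, the header now opens its scope with a single C++17 nested namespace definition and renames `dpnp::backend::ext::blas` to `dpnp::extensions::blas`. The sketch below only demonstrates that the two namespace spellings are equivalent; the `f_old_style`/`f_new_style` functions are placeholders added so the example compiles on its own (build with -std=c++17 or later).

```cpp
// Both blocks contribute to the very same namespace; the second form is the
// compact C++17 nested namespace definition the PR switches to.
namespace dpnp { namespace extensions { namespace blas {
inline int f_old_style() { return 1; }
}}} // namespace dpnp::extensions::blas

namespace dpnp::extensions::blas
{
inline int f_new_style() { return 2; }
} // namespace dpnp::extensions::blas

int main()
{
    // Both functions resolve through the same fully qualified path.
    return (dpnp::extensions::blas::f_old_style() +
            dpnp::extensions::blas::f_new_style()) == 3
               ? 0
               : 1;
}
```
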
44 changes: 16 additions & 28 deletions dpnp/backend/extensions/blas/dot_common.hpp
@@ -36,21 +36,13 @@
 
 #include "types_matrix.hpp"
 
-namespace dpnp
-{
-namespace backend
-{
-namespace ext
-{
-namespace blas
-{
-namespace dot
+namespace dpnp::extensions::blas::dot
 {
 typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
                                          const std::int64_t,
-                                         char *,
+                                         const char *,
                                          const std::int64_t,
-                                         char *,
+                                         const char *,
                                          const std::int64_t,
                                          char *,
                                          const std::vector<sycl::event> &);
@@ -61,9 +53,9 @@ namespace py = pybind11;
 template <typename dispatchT>
 std::pair<sycl::event, sycl::event>
     dot_func(sycl::queue &exec_q,
-             dpctl::tensor::usm_ndarray vectorX,
-             dpctl::tensor::usm_ndarray vectorY,
-             dpctl::tensor::usm_ndarray result,
+             const dpctl::tensor::usm_ndarray &vectorX,
+             const dpctl::tensor::usm_ndarray &vectorY,
+             const dpctl::tensor::usm_ndarray &result,
              const std::vector<sycl::event> &depends,
              const dispatchT &dot_dispatch_vector)
 {
@@ -109,30 +101,30 @@ std::pair<sycl::event, sycl::event>
             "USM allocations are not compatible with the execution queue.");
     }
 
-    size_t src_nelems = 1;
+    const int src_nelems = 1;
     dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
     dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(result,
                                                                src_nelems);
 
-    py::ssize_t x_size = vectorX.get_size();
-    py::ssize_t y_size = vectorY.get_size();
+    const py::ssize_t x_size = vectorX.get_size();
+    const py::ssize_t y_size = vectorY.get_size();
     const std::int64_t n = x_size;
     if (x_size != y_size) {
         throw py::value_error("The size of the first input array must be "
                               "equal to the size of the second input array.");
     }
 
-    int vectorX_typenum = vectorX.get_typenum();
-    int vectorY_typenum = vectorY.get_typenum();
-    int result_typenum = result.get_typenum();
+    const int vectorX_typenum = vectorX.get_typenum();
+    const int vectorY_typenum = vectorY.get_typenum();
+    const int result_typenum = result.get_typenum();
 
     if (result_typenum != vectorX_typenum || result_typenum != vectorY_typenum)
     {
         throw py::value_error("Given arrays must be of the same type.");
     }
 
     auto array_types = dpctl_td_ns::usm_ndarray_types();
-    int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);
+    const int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);
 
     dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
     if (dot_fn == nullptr) {
@@ -144,8 +136,8 @@ std::pair<sycl::event, sycl::event>
     char *y_typeless_ptr = vectorY.get_data();
     char *r_typeless_ptr = result.get_data();
 
-    std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
-    std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
+    const std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
+    const std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
     const int x_elemsize = vectorX.get_elemsize();
     const int y_elemsize = vectorY.get_elemsize();
 
@@ -184,8 +176,4 @@ void init_dot_dispatch_vector(dispatchT dot_dispatch_vector[])
         contig;
     contig.populate_dispatch_vector(dot_dispatch_vector);
 }
-} // namespace dot
-} // namespace blas
-} // namespace ext
-} // namespace backend
-} // namespace dpnp
+} // namespace dpnp::extensions::blas::dot

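`dot_common.hpp` centralizes the validation (queue compatibility, writability, matching sizes and types) and then dispatches through a per-type table of function pointers: `init_dot_dispatch_vector` fills the table via a contiguous-factory template, and `dot_func` indexes it with the lookup id derived from the array's typenum. The sketch below reproduces that dispatch-vector idea in isolation, without dpctl or oneMKL; the type ids, table size, and the reference `dot_impl` are simplified stand-ins, not the library's own definitions.

```cpp
// Minimal sketch of a type-indexed dispatch vector: one function pointer per
// supported element type, populated once and selected at runtime.
#include <array>
#include <cstdint>
#include <stdexcept>

using impl_fn_ptr_t = double (*)(const void *x, const void *y, std::int64_t n);

template <typename T>
double dot_impl(const void *vx, const void *vy, std::int64_t n)
{
    const T *x = static_cast<const T *>(vx);
    const T *y = static_cast<const T *>(vy);
    double acc = 0.0;
    for (std::int64_t i = 0; i < n; ++i) {
        acc += static_cast<double>(x[i]) * static_cast<double>(y[i]);
    }
    return acc;
}

enum TypeId : int { kFloat = 0, kDouble = 1, kNumTypes = 2 };

// Analogue of init_dot_dispatch_vector: fill the table once at startup.
inline void init_dispatch_vector(std::array<impl_fn_ptr_t, kNumTypes> &table)
{
    table[kFloat] = dot_impl<float>;
    table[kDouble] = dot_impl<double>;
}

// Analogue of the lookup inside dot_func: a null entry means the type pair
// is not supported.
inline double dispatch(const std::array<impl_fn_ptr_t, kNumTypes> &table,
                       int type_id, const void *x, const void *y,
                       std::int64_t n)
{
    impl_fn_ptr_t fn = table[type_id];
    if (fn == nullptr) {
        throw std::runtime_error("Types of input arrays are not supported.");
    }
    return fn(x, y, n);
}

int main()
{
    std::array<impl_fn_ptr_t, kNumTypes> table{};
    init_dispatch_vector(table);

    const double xs[] = {1.0, 2.0, 3.0};
    const double ys[] = {4.0, 5.0, 6.0};
    return dispatch(table, kDouble, xs, ys, 3) == 32.0 ? 0 : 1;
}
```
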
21 changes: 6 additions & 15 deletions dpnp/backend/extensions/blas/dotc.hpp
@@ -27,31 +27,25 @@
 
 #include "dot_common.hpp"
 
-namespace dpnp
-{
-namespace backend
-{
-namespace ext
-{
-namespace blas
+namespace dpnp::extensions::blas
 {
 namespace mkl_blas = oneapi::mkl::blas;
 namespace type_utils = dpctl::tensor::type_utils;
 
 template <typename T>
 static sycl::event dotc_impl(sycl::queue &exec_q,
                              const std::int64_t n,
-                             char *vectorX,
+                             const char *vectorX,
                              const std::int64_t incx,
-                             char *vectorY,
+                             const char *vectorY,
                              const std::int64_t incy,
                              char *result,
                              const std::vector<sycl::event> &depends)
 {
     type_utils::validate_type_for_device<T>(exec_q);
 
-    T *x = reinterpret_cast<T *>(vectorX);
-    T *y = reinterpret_cast<T *>(vectorY);
+    const T *x = reinterpret_cast<const T *>(vectorX);
+    const T *y = reinterpret_cast<const T *>(vectorY);
     T *res = reinterpret_cast<T *>(result);
 
     std::stringstream error_msg;
@@ -100,7 +94,4 @@ struct DotcContigFactory
     }
 };
 
-} // namespace blas
-} // namespace ext
-} // namespace backend
-} // namespace dpnp
+} // namespace dpnp::extensions::blas

21 changes: 6 additions & 15 deletions dpnp/backend/extensions/blas/dotu.hpp
@@ -27,31 +27,25 @@
 
 #include "dot_common.hpp"
 
-namespace dpnp
-{
-namespace backend
-{
-namespace ext
-{
-namespace blas
+namespace dpnp::extensions::blas
 {
 namespace mkl_blas = oneapi::mkl::blas;
 namespace type_utils = dpctl::tensor::type_utils;
 
 template <typename T>
 static sycl::event dotu_impl(sycl::queue &exec_q,
                              const std::int64_t n,
-                             char *vectorX,
+                             const char *vectorX,
                              const std::int64_t incx,
-                             char *vectorY,
+                             const char *vectorY,
                              const std::int64_t incy,
                              char *result,
                              const std::vector<sycl::event> &depends)
 {
     type_utils::validate_type_for_device<T>(exec_q);
 
-    T *x = reinterpret_cast<T *>(vectorX);
-    T *y = reinterpret_cast<T *>(vectorY);
+    const T *x = reinterpret_cast<const T *>(vectorX);
+    const T *y = reinterpret_cast<const T *>(vectorY);
     T *res = reinterpret_cast<T *>(result);
 
     std::stringstream error_msg;
@@ -99,7 +93,4 @@ struct DotuContigFactory
         }
     }
 };
-} // namespace blas
-} // namespace ext
-} // namespace backend
-} // namespace dpnp
+} // namespace dpnp::extensions::blas

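`dotu.hpp` gets the same treatment as `dot.hpp` and `dotc.hpp`: the namespace collapses to `dpnp::extensions::blas`, and the typeless input buffers become `const char *`, cast to `const T *` inside the implementation so that only the result pointer stays writable. Below is a minimal standalone illustration of that const-correct cast pattern; it is not the MKL-backed routine, and `sum_impl` is invented purely for the example.

```cpp
// Sketch: typeless byte pointers for inputs are const, only the output is
// writable; the typed views are recovered with reinterpret_cast.
#include <cstdint>

template <typename T>
static T sum_impl(const char *vectorX, std::int64_t n, char *result)
{
    const T *x = reinterpret_cast<const T *>(vectorX); // read-only view
    T *res = reinterpret_cast<T *>(result);            // writable output

    T acc = T(0);
    for (std::int64_t i = 0; i < n; ++i) {
        acc += x[i];
    }
    *res = acc;
    return acc;
}

int main()
{
    const float xs[] = {1.0f, 2.0f, 3.0f};
    float out = 0.0f;
    sum_impl<float>(reinterpret_cast<const char *>(xs), 3,
                    reinterpret_cast<char *>(&out));
    return out == 6.0f ? 0 : 1;
}
```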