Skip to content

Commit a813fae

Browse files
vtavana and antonwolfy authored
update BLAS extension routines (#1884)
Co-authored-by: Anton <[email protected]>
1 parent af601c6 commit a813fae

File tree

11 files changed

+192
-275
lines changed

11 files changed

+192
-275
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 30 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -37,17 +37,17 @@
3737
#include "gemm.hpp"
3838
#include "gemv.hpp"
3939

40-
namespace blas_ext = dpnp::backend::ext::blas;
40+
namespace blas_ns = dpnp::extensions::blas;
4141
namespace py = pybind11;
42-
namespace dot_ext = blas_ext::dot;
43-
using dot_ext::dot_impl_fn_ptr_t;
42+
namespace dot_ns = blas_ns::dot;
43+
using dot_ns::dot_impl_fn_ptr_t;
4444

4545
// populate dispatch vectors and tables
4646
void init_dispatch_vectors_tables(void)
4747
{
48-
blas_ext::init_gemm_batch_dispatch_table();
49-
blas_ext::init_gemm_dispatch_table();
50-
blas_ext::init_gemv_dispatch_vector();
48+
blas_ns::init_gemm_batch_dispatch_table();
49+
blas_ns::init_gemm_dispatch_table();
50+
blas_ns::init_gemv_dispatch_vector();
5151
}
5252

5353
static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
@@ -62,14 +62,15 @@ PYBIND11_MODULE(_blas_impl, m)
6262
using event_vecT = std::vector<sycl::event>;
6363

6464
{
65-
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
66-
blas_ext::DotContigFactory>(
65+
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
66+
blas_ns::DotContigFactory>(
6767
dot_dispatch_vector);
6868

69-
auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
70-
arrayT dst, const event_vecT &depends = {}) {
71-
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
72-
dot_dispatch_vector);
69+
auto dot_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
70+
const arrayT &src2, const arrayT &dst,
71+
const event_vecT &depends = {}) {
72+
return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
73+
dot_dispatch_vector);
7374
};
7475

7576
m.def("_dot", dot_pyapi,
@@ -80,14 +81,15 @@ PYBIND11_MODULE(_blas_impl, m)
8081
}
8182

8283
{
83-
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
84-
blas_ext::DotcContigFactory>(
84+
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
85+
blas_ns::DotcContigFactory>(
8586
dotc_dispatch_vector);
8687

87-
auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
88-
arrayT dst, const event_vecT &depends = {}) {
89-
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
90-
dotc_dispatch_vector);
88+
auto dotc_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
89+
const arrayT &src2, const arrayT &dst,
90+
const event_vecT &depends = {}) {
91+
return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
92+
dotc_dispatch_vector);
9193
};
9294

9395
m.def("_dotc", dotc_pyapi,
@@ -99,14 +101,15 @@ PYBIND11_MODULE(_blas_impl, m)
99101
}
100102

101103
{
102-
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
103-
blas_ext::DotuContigFactory>(
104+
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
105+
blas_ns::DotuContigFactory>(
104106
dotu_dispatch_vector);
105107

106-
auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
107-
arrayT dst, const event_vecT &depends = {}) {
108-
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
109-
dotu_dispatch_vector);
108+
auto dotu_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
109+
const arrayT &src2, const arrayT &dst,
110+
const event_vecT &depends = {}) {
111+
return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
112+
dotu_dispatch_vector);
110113
};
111114

112115
m.def("_dotu", dotu_pyapi,
@@ -117,23 +120,23 @@ PYBIND11_MODULE(_blas_impl, m)
117120
}
118121

119122
{
120-
m.def("_gemm", &blas_ext::gemm,
123+
m.def("_gemm", &blas_ns::gemm,
121124
"Call `gemm` from OneMKL BLAS library to compute "
122125
"the matrix-matrix product with 2-D matrices.",
123126
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
124127
py::arg("resultC"), py::arg("depends") = py::list());
125128
}
126129

127130
{
128-
m.def("_gemm_batch", &blas_ext::gemm_batch,
131+
m.def("_gemm_batch", &blas_ns::gemm_batch,
129132
"Call `gemm_batch` from OneMKL BLAS library to compute "
130133
"the matrix-matrix product for a batch of 2-D matrices.",
131134
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
132135
py::arg("resultC"), py::arg("depends") = py::list());
133136
}
134137

135138
{
136-
m.def("_gemv", &blas_ext::gemv,
139+
m.def("_gemv", &blas_ns::gemv,
137140
"Call `gemv` from OneMKL BLAS library to compute "
138141
"the matrix-vector product with a general matrix.",
139142
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),

dpnp/backend/extensions/blas/dot.hpp

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -27,31 +27,25 @@
2727

2828
#include "dot_common.hpp"
2929

30-
namespace dpnp
31-
{
32-
namespace backend
33-
{
34-
namespace ext
35-
{
36-
namespace blas
30+
namespace dpnp::extensions::blas
3731
{
3832
namespace mkl_blas = oneapi::mkl::blas;
3933
namespace type_utils = dpctl::tensor::type_utils;
4034

4135
template <typename T>
4236
static sycl::event dot_impl(sycl::queue &exec_q,
4337
const std::int64_t n,
44-
char *vectorX,
38+
const char *vectorX,
4539
const std::int64_t incx,
46-
char *vectorY,
40+
const char *vectorY,
4741
const std::int64_t incy,
4842
char *result,
4943
const std::vector<sycl::event> &depends)
5044
{
5145
type_utils::validate_type_for_device<T>(exec_q);
5246

53-
T *x = reinterpret_cast<T *>(vectorX);
54-
T *y = reinterpret_cast<T *>(vectorY);
47+
const T *x = reinterpret_cast<const T *>(vectorX);
48+
const T *y = reinterpret_cast<const T *>(vectorY);
5549
T *res = reinterpret_cast<T *>(result);
5650

5751
std::stringstream error_msg;
@@ -99,7 +93,4 @@ struct DotContigFactory
9993
}
10094
}
10195
};
102-
} // namespace blas
103-
} // namespace ext
104-
} // namespace backend
105-
} // namespace dpnp
96+
} // namespace dpnp::extensions::blas

dpnp/backend/extensions/blas/dot_common.hpp

Lines changed: 16 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -36,21 +36,13 @@
3636

3737
#include "types_matrix.hpp"
3838

39-
namespace dpnp
40-
{
41-
namespace backend
42-
{
43-
namespace ext
44-
{
45-
namespace blas
46-
{
47-
namespace dot
39+
namespace dpnp::extensions::blas::dot
4840
{
4941
typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
5042
const std::int64_t,
51-
char *,
43+
const char *,
5244
const std::int64_t,
53-
char *,
45+
const char *,
5446
const std::int64_t,
5547
char *,
5648
const std::vector<sycl::event> &);
@@ -61,9 +53,9 @@ namespace py = pybind11;
6153
template <typename dispatchT>
6254
std::pair<sycl::event, sycl::event>
6355
dot_func(sycl::queue &exec_q,
64-
dpctl::tensor::usm_ndarray vectorX,
65-
dpctl::tensor::usm_ndarray vectorY,
66-
dpctl::tensor::usm_ndarray result,
56+
const dpctl::tensor::usm_ndarray &vectorX,
57+
const dpctl::tensor::usm_ndarray &vectorY,
58+
const dpctl::tensor::usm_ndarray &result,
6759
const std::vector<sycl::event> &depends,
6860
const dispatchT &dot_dispatch_vector)
6961
{
@@ -109,30 +101,30 @@ std::pair<sycl::event, sycl::event>
109101
"USM allocations are not compatible with the execution queue.");
110102
}
111103

112-
size_t src_nelems = 1;
104+
const int src_nelems = 1;
113105
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
114106
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(result,
115107
src_nelems);
116108

117-
py::ssize_t x_size = vectorX.get_size();
118-
py::ssize_t y_size = vectorY.get_size();
109+
const py::ssize_t x_size = vectorX.get_size();
110+
const py::ssize_t y_size = vectorY.get_size();
119111
const std::int64_t n = x_size;
120112
if (x_size != y_size) {
121113
throw py::value_error("The size of the first input array must be "
122114
"equal to the size of the second input array.");
123115
}
124116

125-
int vectorX_typenum = vectorX.get_typenum();
126-
int vectorY_typenum = vectorY.get_typenum();
127-
int result_typenum = result.get_typenum();
117+
const int vectorX_typenum = vectorX.get_typenum();
118+
const int vectorY_typenum = vectorY.get_typenum();
119+
const int result_typenum = result.get_typenum();
128120

129121
if (result_typenum != vectorX_typenum || result_typenum != vectorY_typenum)
130122
{
131123
throw py::value_error("Given arrays must be of the same type.");
132124
}
133125

134126
auto array_types = dpctl_td_ns::usm_ndarray_types();
135-
int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);
127+
const int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);
136128

137129
dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
138130
if (dot_fn == nullptr) {
@@ -144,8 +136,8 @@ std::pair<sycl::event, sycl::event>
144136
char *y_typeless_ptr = vectorY.get_data();
145137
char *r_typeless_ptr = result.get_data();
146138

147-
std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
148-
std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
139+
const std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
140+
const std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
149141
const int x_elemsize = vectorX.get_elemsize();
150142
const int y_elemsize = vectorY.get_elemsize();
151143

@@ -184,8 +176,4 @@ void init_dot_dispatch_vector(dispatchT dot_dispatch_vector[])
184176
contig;
185177
contig.populate_dispatch_vector(dot_dispatch_vector);
186178
}
187-
} // namespace dot
188-
} // namespace blas
189-
} // namespace ext
190-
} // namespace backend
191-
} // namespace dpnp
179+
} // namespace dpnp::extensions::blas::dot

dpnp/backend/extensions/blas/dotc.hpp

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -27,31 +27,25 @@
2727

2828
#include "dot_common.hpp"
2929

30-
namespace dpnp
31-
{
32-
namespace backend
33-
{
34-
namespace ext
35-
{
36-
namespace blas
30+
namespace dpnp::extensions::blas
3731
{
3832
namespace mkl_blas = oneapi::mkl::blas;
3933
namespace type_utils = dpctl::tensor::type_utils;
4034

4135
template <typename T>
4236
static sycl::event dotc_impl(sycl::queue &exec_q,
4337
const std::int64_t n,
44-
char *vectorX,
38+
const char *vectorX,
4539
const std::int64_t incx,
46-
char *vectorY,
40+
const char *vectorY,
4741
const std::int64_t incy,
4842
char *result,
4943
const std::vector<sycl::event> &depends)
5044
{
5145
type_utils::validate_type_for_device<T>(exec_q);
5246

53-
T *x = reinterpret_cast<T *>(vectorX);
54-
T *y = reinterpret_cast<T *>(vectorY);
47+
const T *x = reinterpret_cast<const T *>(vectorX);
48+
const T *y = reinterpret_cast<const T *>(vectorY);
5549
T *res = reinterpret_cast<T *>(result);
5650

5751
std::stringstream error_msg;
@@ -100,7 +94,4 @@ struct DotcContigFactory
10094
}
10195
};
10296

103-
} // namespace blas
104-
} // namespace ext
105-
} // namespace backend
106-
} // namespace dpnp
97+
} // namespace dpnp::extensions::blas

dpnp/backend/extensions/blas/dotu.hpp

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -27,31 +27,25 @@
2727

2828
#include "dot_common.hpp"
2929

30-
namespace dpnp
31-
{
32-
namespace backend
33-
{
34-
namespace ext
35-
{
36-
namespace blas
30+
namespace dpnp::extensions::blas
3731
{
3832
namespace mkl_blas = oneapi::mkl::blas;
3933
namespace type_utils = dpctl::tensor::type_utils;
4034

4135
template <typename T>
4236
static sycl::event dotu_impl(sycl::queue &exec_q,
4337
const std::int64_t n,
44-
char *vectorX,
38+
const char *vectorX,
4539
const std::int64_t incx,
46-
char *vectorY,
40+
const char *vectorY,
4741
const std::int64_t incy,
4842
char *result,
4943
const std::vector<sycl::event> &depends)
5044
{
5145
type_utils::validate_type_for_device<T>(exec_q);
5246

53-
T *x = reinterpret_cast<T *>(vectorX);
54-
T *y = reinterpret_cast<T *>(vectorY);
47+
const T *x = reinterpret_cast<const T *>(vectorX);
48+
const T *y = reinterpret_cast<const T *>(vectorY);
5549
T *res = reinterpret_cast<T *>(result);
5650

5751
std::stringstream error_msg;
@@ -99,7 +93,4 @@ struct DotuContigFactory
9993
}
10094
}
10195
};
102-
} // namespace blas
103-
} // namespace ext
104-
} // namespace backend
105-
} // namespace dpnp
96+
} // namespace dpnp::extensions::blas

0 commit comments

Comments
 (0)