Skip to content

Commit d8acc9d

Browse files
authored
Merge 00a7755 into a8e6fce
2 parents a8e6fce + 00a7755 commit d8acc9d

File tree

5 files changed

+206
-159
lines changed

5 files changed

+206
-159
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m)
6464
blas_ext::DotContigFactory>(
6565
dot_dispatch_vector);
6666

67-
auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
68-
arrayT dst, const event_vecT &depends = {}) {
67+
auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
68+
arrayT dst, const event_vecT &depends = {}) {
6969
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
7070
dot_dispatch_vector);
7171
};
7272

73-
m.def("_dot", dot_pypi,
73+
m.def("_dot", dot_pyapi,
7474
"Call `dot` from OneMKL BLAS library to return "
7575
"the dot product of two real-valued vectors.",
7676
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m)
8282
blas_ext::DotcContigFactory>(
8383
dotc_dispatch_vector);
8484

85-
auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
86-
arrayT dst, const event_vecT &depends = {}) {
85+
auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
86+
arrayT dst, const event_vecT &depends = {}) {
8787
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
8888
dotc_dispatch_vector);
8989
};
9090

91-
m.def("_dotc", dotc_pypi,
91+
m.def("_dotc", dotc_pyapi,
9292
"Call `dotc` from OneMKL BLAS library to return "
9393
"the dot product of two complex vectors, "
9494
"conjugating the first vector.",
@@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m)
101101
blas_ext::DotuContigFactory>(
102102
dotu_dispatch_vector);
103103

104-
auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
105-
arrayT dst, const event_vecT &depends = {}) {
104+
auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
105+
arrayT dst, const event_vecT &depends = {}) {
106106
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
107107
dotu_dispatch_vector);
108108
};
109109

110-
m.def("_dotu", dotu_pypi,
110+
m.def("_dotu", dotu_pyapi,
111111
"Call `dotu` from OneMKL BLAS library to return "
112112
"the dot product of two complex vectors.",
113113
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -119,16 +119,14 @@ PYBIND11_MODULE(_blas_impl, m)
119119
"Call `gemm` from OneMKL BLAS library to return "
120120
"the matrix-matrix product with 2-D matrices.",
121121
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
122-
py::arg("result"), py::arg("depends") = py::list());
122+
py::arg("resultC"), py::arg("depends") = py::list());
123123
}
124124

125125
{
126126
m.def("_gemm_batch", &blas_ext::gemm_batch,
127127
"Call `gemm_batch` from OneMKL BLAS library to return "
128128
"the matrix-matrix product for a batch of 2-D matrices.",
129129
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
130-
py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
131-
py::arg("strideb"), py::arg("stridec"),
132-
py::arg("depends") = py::list());
130+
py::arg("resultC"), py::arg("depends") = py::list());
133131
}
134132
}

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -50,10 +50,6 @@ extern std::pair<sycl::event, sycl::event>
5050
dpctl::tensor::usm_ndarray matrixA,
5151
dpctl::tensor::usm_ndarray matrixB,
5252
dpctl::tensor::usm_ndarray resultC,
53-
const std::int64_t batch_size,
54-
size_t stridea,
55-
size_t strideb,
56-
size_t stridec,
5753
const std::vector<sycl::event> &depends);
5854

5955
extern void init_gemm_dispatch_table(void);

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 41 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -150,10 +150,6 @@ std::pair<sycl::event, sycl::event>
150150
dpctl::tensor::usm_ndarray matrixA,
151151
dpctl::tensor::usm_ndarray matrixB,
152152
dpctl::tensor::usm_ndarray resultC,
153-
const std::int64_t batch_size,
154-
size_t stridea,
155-
size_t strideb,
156-
size_t stridec,
157153
const std::vector<sycl::event> &depends = {})
158154
{
159155
const int matrixA_nd = matrixA.get_ndim();
@@ -185,49 +181,60 @@ std::pair<sycl::event, sycl::event>
185181
const py::ssize_t *a_shape = matrixA.get_shape_raw();
186182
const py::ssize_t *b_shape = matrixB.get_shape_raw();
187183
const py::ssize_t *c_shape = resultC.get_shape_raw();
188-
const std::int64_t m = a_shape[matrixA_nd - 2];
189-
const std::int64_t n = b_shape[matrixB_nd - 1];
190-
const std::int64_t k = a_shape[matrixA_nd - 1];
191-
if (a_shape[matrixA_nd - 1] != b_shape[matrixB_nd - 2]) {
184+
const std::int64_t m = a_shape[1];
185+
const std::int64_t n = b_shape[2];
186+
const std::int64_t k = a_shape[2];
187+
const std::int64_t batch_size = c_shape[0];
188+
if (a_shape[2] != b_shape[1]) {
192189
throw py::value_error("The number of columns in A must be equal to "
193190
"the number of rows in B.");
194191
}
195-
if (a_shape[matrixA_nd - 2] != c_shape[resultC_nd - 2]) {
192+
if (a_shape[1] != c_shape[1]) {
196193
throw py::value_error("The number of rows in A must be equal to "
197194
"the number of rows in result array.");
198195
}
199-
if (b_shape[matrixB_nd - 1] != c_shape[resultC_nd - 1]) {
196+
if (b_shape[2] != c_shape[2]) {
200197
throw py::value_error("The number of columns in B must be equal to "
201198
"the number of columns in result array.");
202199
}
203200

204-
bool shapes_equal = true;
205-
size_t src_nelems = 1;
206-
py::ssize_t lead_dim;
207-
for (int i = 0; i < matrixA_nd - 2; ++i) {
208-
if (a_shape[i] == b_shape[i]) {
209-
lead_dim = a_shape[i];
210-
}
211-
else if (a_shape[i] == 1 || b_shape[i] == 1) {
212-
lead_dim = std::max(a_shape[i], b_shape[i]);
213-
}
214-
else {
215-
throw py::value_error("Array shapes do not match.");
216-
}
217-
src_nelems *= static_cast<size_t>(lead_dim);
218-
shapes_equal = shapes_equal && (lead_dim == c_shape[i]);
201+
std::int64_t first_dim;
202+
if (a_shape[0] == b_shape[0]) {
203+
first_dim = a_shape[0];
204+
}
205+
else if (a_shape[0] == 1 || b_shape[0] == 1) {
206+
first_dim = std::max(a_shape[0], b_shape[0]);
219207
}
220-
src_nelems *= (m * n);
221-
if (!shapes_equal) {
208+
else {
222209
throw py::value_error("Array shapes do not match.");
223210
}
211+
if (first_dim != c_shape[0]) {
212+
throw py::value_error("Array shapes do not match.");
213+
}
214+
std::int64_t src_nelems = first_dim * m * n;
224215
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
225216
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
226217
src_nelems);
227218

228-
// transA and transB are always False
229-
oneapi::mkl::transpose transA = oneapi::mkl::transpose::N;
230-
oneapi::mkl::transpose transB = oneapi::mkl::transpose::N;
219+
std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
220+
std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
221+
std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
222+
const std::int64_t stridea = a_stride[0];
223+
const std::int64_t strideb = b_stride[0];
224+
const std::int64_t stridec = c_stride[0];
225+
bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
226+
bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];
227+
228+
oneapi::mkl::transpose transA = A_base_is_f_contig
229+
? oneapi::mkl::transpose::T
230+
: oneapi::mkl::transpose::N;
231+
oneapi::mkl::transpose transB = B_base_is_f_contig
232+
? oneapi::mkl::transpose::T
233+
: oneapi::mkl::transpose::N;
234+
235+
const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
236+
const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
237+
const std::int64_t ldc = n; // always n for row_major
231238

232239
int matrixA_typenum = matrixA.get_typenum();
233240
int matrixB_typenum = matrixB.get_typenum();
@@ -252,10 +259,10 @@ std::pair<sycl::event, sycl::event>
252259
char *b_typeless_ptr = matrixB.get_data();
253260
char *r_typeless_ptr = resultC.get_data();
254261

255-
// Note that lda = k, ldb = n, and ld_result = n
256-
sycl::event gemm_batch_ev = gemm_batch_fn(
257-
exec_q, m, n, k, batch_size, k, n, n, stridea, strideb, stridec, transA,
258-
transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, depends);
262+
sycl::event gemm_batch_ev =
263+
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
264+
strideb, stridec, transA, transB, a_typeless_ptr,
265+
b_typeless_ptr, r_typeless_ptr, depends);
259266

260267
sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
261268
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

0 commit comments

Comments
 (0)