Skip to content

Commit b8f7f00

Browse files
committed
Add support for N-D array
add N-dimension
1 parent 3444816 commit b8f7f00

File tree

9 files changed

+617
-140
lines changed

9 files changed

+617
-140
lines changed

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(python_module_name _blas_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
3132
)
3233

3334
pybind11_add_module(${python_module_name} MODULE ${_module_src})

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,25 @@ namespace py = pybind11;
3939
void init_dispatch_tables(void)
4040
{
4141
blas_ext::init_gemm_dispatch_table();
42+
blas_ext::init_gemm_batch_dispatch_table();
4243
}
4344

4445
PYBIND11_MODULE(_blas_impl, m)
4546
{
4647
init_dispatch_tables();
4748

48-
m.def("_gemm", &blas_ext::gemm,
49-
"Call `gemm` from OneMKL LAPACK library to return "
50-
"the matrix-matrix product with general matrices.",
51-
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
52-
py::arg("matrixC"), py::arg("depends") = py::list());
49+
{
50+
m.def("_gemm", &blas_ext::gemm,
51+
"Call `gemm` from OneMKL LAPACK library to return "
52+
"the matrix-matrix product with 2-D matrices.",
53+
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
54+
py::arg("matrixC"), py::arg("isRowMajor"),
55+
py::arg("depends") = py::list());
56+
}
57+
58+
{
59+
m.def("_gemm_batch", &blas_ext::gemm_batch,
60+
"Call `gemm_batch` from OneMKL LAPACK library to return "
61+
"the matrix-matrix product with general matrices.");
62+
}
5363
}

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
5858
const std::int64_t,
5959
char *,
6060
const std::int64_t,
61+
const bool,
6162
const std::vector<sycl::event> &);
6263

6364
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -76,6 +77,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
7677
const std::int64_t ld_array_2,
7778
char *resultC,
7879
const std::int64_t ld_result,
80+
const bool isRowMajor,
7981
const std::vector<sycl::event> &depends)
8082
{
8183
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -92,24 +94,54 @@ static sycl::event gemm_impl(sycl::queue exec_q,
9294
sycl::event gemm_event;
9395
try {
9496
// Need to add logic to call column_major::gemm
95-
gemm_event = mkl_blas::row_major::gemm(
96-
exec_q,
97-
transA, // Parameter indicating whether matrix A is not transposed
98-
// ('N'), transposed ('T'), or conjugate transposed ('C').
99-
transB, // Same as transA but for matrix B.
100-
m, // Number of rows in matrices A and C.
101-
n, // Number of columns in matrices B and C.
102-
k, // Number of columns in matrix A and rows in matrix B.
103-
Tab(1), // Scaling factor for the product of matrices A and B.
104-
a, // Pointer to matrix A.
105-
ld_array_1, // Leading dimension of matrix A, which is the stride
106-
// between successive rows (for row major layout).
107-
b, // Pointer to matrix B.
108-
ld_array_2, // Leading dimension of matrix B, similar to ld_array_1.
109-
Tab(0), // Scaling factor for matrix C.
110-
res, // Pointer to matrix C, where the result is stored.
111-
ld_result, // Leading dimension of matrix C.
112-
depends);
97+
if (isRowMajor) {
98+
gemm_event = mkl_blas::row_major::gemm(
99+
exec_q,
100+
transA, // Parameter indicating whether matrix A is not
101+
// transposed
102+
// ('N'), transposed ('T'), or conjugate transposed
103+
// ('C').
104+
transB, // Same as transA but for matrix B.
105+
m, // Number of rows in matrices A and C.
106+
n, // Number of columns in matrices B and C.
107+
k, // Number of columns in matrix A and rows in matrix B.
108+
Tab(1), // Scaling factor for the product of matrices A and B.
109+
a, // Pointer to matrix A.
110+
ld_array_1, // Leading dimension of matrix A, which is the
111+
// stride between successive rows (for row major
112+
// layout).
113+
b, // Pointer to matrix B.
114+
ld_array_2, // Leading dimension of matrix B, similar to
115+
// ld_array_1.
116+
Tab(0), // Scaling factor for matrix C.
117+
res, // Pointer to matrix C, where the result is stored.
118+
ld_result, // Leading dimension of matrix C.
119+
depends);
120+
}
121+
else {
122+
gemm_event = mkl_blas::column_major::gemm(
123+
exec_q,
124+
transA, // Parameter indicating whether matrix A is not
125+
// transposed
126+
// ('N'), transposed ('T'), or conjugate transposed
127+
// ('C').
128+
transB, // Same as transA but for matrix B.
129+
m, // Number of rows in matrices A and C.
130+
n, // Number of columns in matrices B and C.
131+
k, // Number of columns in matrix A and rows in matrix B.
132+
Tab(1), // Scaling factor for the product of matrices A and B.
133+
a, // Pointer to matrix A.
134+
ld_array_1, // Leading dimension of matrix A, which is the
135+
// stride between successive rows (for row major
136+
// layout).
137+
b, // Pointer to matrix B.
138+
ld_array_2, // Leading dimension of matrix B, similar to
139+
// ld_array_1.
140+
Tab(0), // Scaling factor for matrix C.
141+
res, // Pointer to matrix C, where the result is stored.
142+
ld_result, // Leading dimension of matrix C.
143+
depends);
144+
}
113145
} catch (oneapi::mkl::exception const &e) {
114146
error_msg
115147
<< "Unexpected MKL exception caught during gemm() call:\nreason: "
@@ -134,6 +166,7 @@ std::pair<sycl::event, sycl::event>
134166
dpctl::tensor::usm_ndarray matrixA,
135167
dpctl::tensor::usm_ndarray matrixB,
136168
dpctl::tensor::usm_ndarray resultC,
169+
const bool isRowMajor,
137170
const std::vector<sycl::event> &depends)
138171
{
139172
const int matrixA_nd = matrixA.get_ndim();
@@ -234,7 +267,8 @@ std::pair<sycl::event, sycl::event>
234267
std::vector<sycl::event> host_task_events;
235268
sycl::event gemm_ev =
236269
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, ld_array_1,
237-
b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result, depends);
270+
b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result,
271+
isRowMajor, depends);
238272

239273
sycl::event args_ev = dpctl::utils::keep_args_alive(
240274
exec_q, {matrixA, matrixB, resultC}, host_task_events);

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,31 @@ extern std::pair<sycl::event, sycl::event>
4343
dpctl::tensor::usm_ndarray matrixA,
4444
dpctl::tensor::usm_ndarray matrixB,
4545
dpctl::tensor::usm_ndarray resultC,
46+
const bool isRowMajor,
4647
const std::vector<sycl::event> &depends);
4748

49+
// extern sycl::event
50+
extern std::pair<sycl::event, sycl::event>
51+
gemm_batch(sycl::queue q,
52+
dpctl::tensor::usm_ndarray matrixA,
53+
dpctl::tensor::usm_ndarray matrixB,
54+
dpctl::tensor::usm_ndarray resultC,
55+
const std::int64_t m,
56+
const std::int64_t n,
57+
const std::int64_t k,
58+
const std::int64_t batch_size,
59+
const std::int64_t ld_array_1,
60+
const std::int64_t ld_array_2,
61+
const std::int64_t ld_result,
62+
size_t stridea,
63+
size_t strideb,
64+
size_t stridec,
65+
const std::int64_t transA_int,
66+
const std::int64_t transB_int,
67+
const std::vector<sycl::event> &depends);
68+
4869
extern void init_gemm_dispatch_table(void);
70+
extern void init_gemm_batch_dispatch_table(void);
4971
} // namespace blas
5072
} // namespace ext
5173
} // namespace backend

0 commit comments

Comments
 (0)