@@ -150,10 +150,6 @@ std::pair<sycl::event, sycl::event>
                dpctl::tensor::usm_ndarray matrixA,
                dpctl::tensor::usm_ndarray matrixB,
                dpctl::tensor::usm_ndarray resultC,
-               const std::int64_t batch_size,
-               size_t stridea,
-               size_t strideb,
-               size_t stridec,
                const std::vector<sycl::event> &depends = {})
 {
     const int matrixA_nd = matrixA.get_ndim();
@@ -185,49 +181,60 @@ std::pair<sycl::event, sycl::event>
     const py::ssize_t *a_shape = matrixA.get_shape_raw();
     const py::ssize_t *b_shape = matrixB.get_shape_raw();
     const py::ssize_t *c_shape = resultC.get_shape_raw();
-    const std::int64_t m = a_shape[matrixA_nd - 2];
-    const std::int64_t n = b_shape[matrixB_nd - 1];
-    const std::int64_t k = a_shape[matrixA_nd - 1];
-    if (a_shape[matrixA_nd - 1] != b_shape[matrixB_nd - 2]) {
+    const std::int64_t m = a_shape[1];
+    const std::int64_t n = b_shape[2];
+    const std::int64_t k = a_shape[2];
+    const std::int64_t batch_size = c_shape[0];
+    if (a_shape[2] != b_shape[1]) {
         throw py::value_error("The number of columns in A must be equal to "
                               "the number of rows in B.");
     }
-    if (a_shape[matrixA_nd - 2] != c_shape[resultC_nd - 2]) {
+    if (a_shape[1] != c_shape[1]) {
         throw py::value_error("The number of rows in A must be equal to "
                               "the number of rows in result array.");
     }
-    if (b_shape[matrixB_nd - 1] != c_shape[resultC_nd - 1]) {
+    if (b_shape[2] != c_shape[2]) {
         throw py::value_error("The number of columns in B must be equal to "
                               "the number of columns in result array.");
     }

-    bool shapes_equal = true;
-    size_t src_nelems = 1;
-    py::ssize_t lead_dim;
-    for (int i = 0; i < matrixA_nd - 2; ++i) {
-        if (a_shape[i] == b_shape[i]) {
-            lead_dim = a_shape[i];
-        }
-        else if (a_shape[i] == 1 || b_shape[i] == 1) {
-            lead_dim = std::max(a_shape[i], b_shape[i]);
-        }
-        else {
-            throw py::value_error("Array shapes do not match.");
-        }
-        src_nelems *= static_cast<size_t>(lead_dim);
-        shapes_equal = shapes_equal && (lead_dim == c_shape[i]);
+    std::int64_t first_dim;
+    if (a_shape[0] == b_shape[0]) {
+        first_dim = a_shape[0];
+    }
+    else if (a_shape[0] == 1 || b_shape[0] == 1) {
+        first_dim = std::max(a_shape[0], b_shape[0]);
     }
-    src_nelems *= (m * n);
-    if (!shapes_equal) {
+    else {
         throw py::value_error("Array shapes do not match.");
     }
+    if (first_dim != c_shape[0]) {
+        throw py::value_error("Array shapes do not match.");
+    }
+    std::int64_t src_nelems = first_dim * m * n;

     dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
     dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
                                                                src_nelems);

-    // transA and transB are always False
-    oneapi::mkl::transpose transA = oneapi::mkl::transpose::N;
-    oneapi::mkl::transpose transB = oneapi::mkl::transpose::N;
+    std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
+    std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
+    std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
+    const std::int64_t stridea = a_stride[0];
+    const std::int64_t strideb = b_stride[0];
+    const std::int64_t stridec = c_stride[0];
+    bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
+    bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];
+
+    oneapi::mkl::transpose transA = A_base_is_f_contig
+                                        ? oneapi::mkl::transpose::T
+                                        : oneapi::mkl::transpose::N;
+    oneapi::mkl::transpose transB = B_base_is_f_contig
+                                        ? oneapi::mkl::transpose::T
+                                        : oneapi::mkl::transpose::N;
+
+    const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
+    const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
+    const std::int64_t ldc = n; // always n for row_major

     int matrixA_typenum = matrixA.get_typenum();
     int matrixB_typenum = matrixB.get_typenum();
@@ -252,10 +259,10 @@ std::pair<sycl::event, sycl::event>
     char *b_typeless_ptr = matrixB.get_data();
     char *r_typeless_ptr = resultC.get_data();

-    // Note that lda = k, ldb = n, and ld_result = n
-    sycl::event gemm_batch_ev = gemm_batch_fn(
-        exec_q, m, n, k, batch_size, k, n, n, stridea, strideb, stridec, transA,
-        transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, depends);
+    sycl::event gemm_batch_ev =
+        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
+                      strideb, stridec, transA, transB, a_typeless_ptr,
+                      b_typeless_ptr, r_typeless_ptr, depends);

     sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
         exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
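The stride-based layout detection introduced above can be exercised in isolation. Below is a minimal standalone sketch (not the dpnp implementation): the `Layout3D` struct and `describe_a_layout` helper are hypothetical names introduced only for illustration, and the rule mirrors the `A_base_is_f_contig` / `lda` / `stridea` computation added in this diff under the row_major convention.

// Sketch: derive transA, lda, and the batch stride for matrix A of a
// (batch, m, k) stack from its shape and element strides, using the same
// rule as the diff above.
#include <array>
#include <cstdint>
#include <iostream>

struct Layout3D
{
    std::array<std::int64_t, 3> shape;   // (batch, rows, cols)
    std::array<std::int64_t, 3> strides; // element strides per dimension
};

// The 2-D base of A is treated as F-contiguous when its innermost stride is 1
// and its middle stride equals the number of rows; in that case A is passed
// transposed and lda becomes m instead of k.
void describe_a_layout(const Layout3D &a)
{
    const std::int64_t m = a.shape[1];
    const std::int64_t k = a.shape[2];

    const bool a_base_is_f_contig =
        a.strides[1] == 1 && a.strides[2] == a.shape[1];

    const char *transA = a_base_is_f_contig ? "T" : "N";
    const std::int64_t lda = a_base_is_f_contig ? m : k;
    const std::int64_t stridea = a.strides[0]; // distance between batches

    std::cout << "transA=" << transA << " lda=" << lda
              << " stridea=" << stridea << "\n";
}

int main()
{
    const std::int64_t batch = 4, m = 3, k = 5;

    // C-contiguous stack: strides (m*k, k, 1) -> transA=N, lda=k
    describe_a_layout({{batch, m, k}, {m * k, k, 1}});

    // Per-batch F-contiguous base: strides (m*k, 1, m) -> transA=T, lda=m
    describe_a_layout({{batch, m, k}, {m * k, 1, m}});

    return 0;
}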