@@ -58,6 +58,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
58
58
const std::int64_t ,
59
59
char *,
60
60
const std::int64_t ,
61
+ const bool ,
61
62
const std::vector<sycl::event> &);
62
63
63
64
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -76,6 +77,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
76
77
const std::int64_t ld_array_2,
77
78
char *resultC,
78
79
const std::int64_t ld_result,
80
+ const bool isRowMajor,
79
81
const std::vector<sycl::event> &depends)
80
82
{
81
83
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -92,24 +94,54 @@ static sycl::event gemm_impl(sycl::queue exec_q,
92
94
sycl::event gemm_event;
93
95
try {
94
96
// Need to add logic to call column_major::gemm
95
- gemm_event = mkl_blas::row_major::gemm (
96
- exec_q,
97
- transA, // Parameter indicating whether matrix A is not transposed
98
- // ('N'), transposed ('T'), or conjugate transposed ('C').
99
- transB, // Same as transA but for matrix B.
100
- m, // Number of rows in matrices A and C.
101
- n, // Number of columns in matrices B and C.
102
- k, // Number of columns in matrix A and rows in matrix B.
103
- Tab (1 ), // Scaling factor for the product of matrices A and B.
104
- a, // Pointer to matrix A.
105
- ld_array_1, // Leading dimension of matrix A, which is the stride
106
- // between successive rows (for row major layout).
107
- b, // Pointer to matrix B.
108
- ld_array_2, // Leading dimension of matrix B, similar to ld_array_1.
109
- Tab (0 ), // Scaling factor for matrix C.
110
- res, // Pointer to matrix C, where the result is stored.
111
- ld_result, // Leading dimension of matrix C.
112
- depends);
97
+ if (isRowMajor) {
98
+ gemm_event = mkl_blas::row_major::gemm (
99
+ exec_q,
100
+ transA, // Parameter indicating whether matrix A is not
101
+ // transposed
102
+ // ('N'), transposed ('T'), or conjugate transposed
103
+ // ('C').
104
+ transB, // Same as transA but for matrix B.
105
+ m, // Number of rows in matrices A and C.
106
+ n, // Number of columns in matrices B and C.
107
+ k, // Number of columns in matrix A and rows in matrix B.
108
+ Tab (1 ), // Scaling factor for the product of matrices A and B.
109
+ a, // Pointer to matrix A.
110
+ ld_array_1, // Leading dimension of matrix A, which is the
111
+ // stride between successive rows (for row major
112
+ // layout).
113
+ b, // Pointer to matrix B.
114
+ ld_array_2, // Leading dimension of matrix B, similar to
115
+ // ld_array_1.
116
+ Tab (0 ), // Scaling factor for matrix C.
117
+ res, // Pointer to matrix C, where the result is stored.
118
+ ld_result, // Leading dimension of matrix C.
119
+ depends);
120
+ }
121
+ else {
122
+ gemm_event = mkl_blas::column_major::gemm (
123
+ exec_q,
124
+ transA, // Parameter indicating whether matrix A is not
125
+ // transposed
126
+ // ('N'), transposed ('T'), or conjugate transposed
127
+ // ('C').
128
+ transB, // Same as transA but for matrix B.
129
+ m, // Number of rows in matrices A and C.
130
+ n, // Number of columns in matrices B and C.
131
+ k, // Number of columns in matrix A and rows in matrix B.
132
+ Tab (1 ), // Scaling factor for the product of matrices A and B.
133
+ a, // Pointer to matrix A.
134
+ ld_array_1, // Leading dimension of matrix A, which is the
135
+ // stride between successive rows (for row major
136
+ // layout).
137
+ b, // Pointer to matrix B.
138
+ ld_array_2, // Leading dimension of matrix B, similar to
139
+ // ld_array_1.
140
+ Tab (0 ), // Scaling factor for matrix C.
141
+ res, // Pointer to matrix C, where the result is stored.
142
+ ld_result, // Leading dimension of matrix C.
143
+ depends);
144
+ }
113
145
} catch (oneapi::mkl::exception const &e) {
114
146
error_msg
115
147
<< " Unexpected MKL exception caught during gemm() call:\n reason: "
@@ -134,6 +166,7 @@ std::pair<sycl::event, sycl::event>
134
166
dpctl::tensor::usm_ndarray matrixA,
135
167
dpctl::tensor::usm_ndarray matrixB,
136
168
dpctl::tensor::usm_ndarray resultC,
169
+ const bool isRowMajor,
137
170
const std::vector<sycl::event> &depends)
138
171
{
139
172
const int matrixA_nd = matrixA.get_ndim ();
@@ -234,7 +267,8 @@ std::pair<sycl::event, sycl::event>
234
267
std::vector<sycl::event> host_task_events;
235
268
sycl::event gemm_ev =
236
269
gemm_fn (exec_q, transA, transB, m, n, k, a_typeless_ptr, ld_array_1,
237
- b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result, depends);
270
+ b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result,
271
+ isRowMajor, depends);
238
272
239
273
sycl::event args_ev = dpctl::utils::keep_args_alive (
240
274
exec_q, {matrixA, matrixB, resultC}, host_task_events);
0 commit comments