@@ -1064,7 +1064,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
rhs_shape_strides);
OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides);

- if (m == 1) {
+ if (k > n && k > m || m == 1) {
constexpr size_t m_groups = 1;
size_t delta_k(4);
size_t n_wi(4);
@@ -1099,42 +1099,46 @@ sycl::event gemm_impl(sycl::queue &exec_q,
lhs_tp, rhs_tp, res_tp, workspace, local_B_block,
n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
m, lhs_indexer, rhs_indexer, res_indexer));
- }
- else if (k > n && k > m) {
- constexpr size_t m_groups = 2;
- size_t delta_k(4);
- size_t n_wi(4);
- size_t delta_n(4);
-
- gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
- local_mem_size, reserved_slm_size, delta_k,
- n_wi, // modified by reference
- delta_n // modified by reference
- );
-
- size_t n_blocks = (n + delta_n - 1) / delta_n;
- size_t m_blocks = (m + m_groups - 1) / m_groups;
- size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k);
-
- size_t lws = delta_n * delta_k;
-
- auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws);
- auto lRange = sycl::range<1>(lws);
-
- auto ndRange = sycl::nd_range<1>(gRange, lRange);
-
- using LocAccT = sycl::local_accessor<sycl::vec<resTy, m_groups>, 1>;
- LocAccT local_B_block(n_wi * delta_k, cgh);
- LocAccT workspace(delta_n * delta_k, cgh);
-
- using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
- OuterInnerIndexerT, m_groups>;
- cgh.parallel_for<KernelName>(
- ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
- OuterInnerIndexerT, m_groups>(
- lhs_tp, rhs_tp, res_tp, workspace, local_B_block,
- n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
- m, lhs_indexer, rhs_indexer, res_indexer));
+ // }
+ // else if (k > n && k > m) {
+ // constexpr size_t m_groups = 2;
+ // size_t delta_k(4);
+ // size_t n_wi(4);
+ // size_t delta_n(4);
+
+ // gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
+ // local_mem_size, reserved_slm_size, delta_k,
+ // n_wi, // modified by reference
+ // delta_n // modified by reference
+ // );
+
+ // size_t n_blocks = (n + delta_n - 1) / delta_n;
+ // size_t m_blocks = (m + m_groups - 1) / m_groups;
+ // size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi *
+ // delta_k);
+
+ // size_t lws = delta_n * delta_k;
+
+ // auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks *
+ // lws); auto lRange = sycl::range<1>(lws);
+
+ // auto ndRange = sycl::nd_range<1>(gRange, lRange);
+
+ // using LocAccT = sycl::local_accessor<sycl::vec<resTy,
+ // m_groups>, 1>; LocAccT local_B_block(n_wi * delta_k, cgh);
+ // LocAccT workspace(delta_n * delta_k, cgh);
+
+ // using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
+ // OuterInnerIndexerT,
+ // m_groups>;
+ // cgh.parallel_for<KernelName>(
+ // ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
+ // OuterInnerIndexerT,
+ // m_groups>(
+ // lhs_tp, rhs_tp, res_tp, workspace,
+ // local_B_block, n, n_blocks, delta_n, k,
+ // k_blocks, delta_k, n_wi, m, lhs_indexer,
+ // rhs_indexer, res_indexer));
}
else {
constexpr int wi_delta_n = 2;
@@ -1230,7 +1234,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
OuterInnerIndexerT rhs_indexer{};
OuterInnerIndexerT res_indexer{};

- if (m == 1) {
+ if (k > n && k > m || m == 1) {
constexpr size_t m_groups = 1;
size_t delta_k(4);
size_t n_wi(4);
@@ -1266,42 +1270,44 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
m, lhs_indexer, rhs_indexer, res_indexer));
}
- else if (k > n && k > m) {
- constexpr size_t m_groups = 2;
- size_t delta_k(4);
- size_t n_wi(4);
- size_t delta_n(4);
-
- gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
- local_mem_size, reserved_slm_size, delta_k,
- n_wi, // modified by reference
- delta_n // modified by reference
- );
-
- size_t n_blocks = (n + delta_n - 1) / delta_n;
- size_t m_blocks = (m + m_groups - 1) / m_groups;
- size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k);
-
- size_t lws = delta_n * delta_k;
-
- auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws);
- auto lRange = sycl::range<1>(lws);
-
- auto ndRange = sycl::nd_range<1>(gRange, lRange);
-
- using LocAccT = sycl::local_accessor<sycl::vec<resTy, m_groups>, 1>;
- LocAccT local_B_block(n_wi * delta_k, cgh);
- LocAccT workspace(delta_n * delta_k, cgh);
-
- using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
- OuterInnerIndexerT, m_groups>;
- cgh.parallel_for<KernelName>(
- ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
- OuterInnerIndexerT, m_groups>(
- lhs_tp, rhs_tp, res_tp, workspace, local_B_block,
- n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
- m, lhs_indexer, rhs_indexer, res_indexer));
- }
+ // else if (k > n && k > m) {
+ // constexpr size_t m_groups = 2;
+ // size_t delta_k(4);
+ // size_t n_wi(4);
+ // size_t delta_n(4);
+
+ // gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
+ // local_mem_size, reserved_slm_size, delta_k,
+ // n_wi, // modified by reference
+ // delta_n // modified by reference
+ // );
+
+ // size_t n_blocks = (n + delta_n - 1) / delta_n;
+ // size_t m_blocks = (m + m_groups - 1) / m_groups;
+ // size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k);
+
+ // size_t lws = delta_n * delta_k;
+
+ // auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks *
+ // lws); auto lRange = sycl::range<1>(lws);
+
+ // auto ndRange = sycl::nd_range<1>(gRange, lRange);
+
+ // using LocAccT = sycl::local_accessor<sycl::vec<resTy, m_groups>,
+ // 1>; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT
+ // workspace(delta_n * delta_k, cgh);
+
+ // using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
+ // OuterInnerIndexerT,
+ // m_groups>;
+ // cgh.parallel_for<KernelName>(
+ // ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
+ // OuterInnerIndexerT, m_groups>(
+ // lhs_tp, rhs_tp, res_tp, workspace,
+ // local_B_block, n, n_blocks, delta_n, k,
+ // k_blocks, delta_k, n_wi, m, lhs_indexer,
+ // rhs_indexer, res_indexer));
+ // }
else {
constexpr int wi_delta_n = 2;
constexpr int wi_delta_m = 4;