Skip to content

Commit 00cbc35

Browse files
committed
Test removal of k-threading gemm kernel which writes to multiple outputs atomically
1 parent 144ac0f commit 00cbc35

File tree

1 file changed

+80
-74
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+80
-74
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 80 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
10641064
rhs_shape_strides);
10651065
OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides);
10661066

1067-
if (m == 1) {
1067+
if (k > n && k > m || m == 1) {
10681068
constexpr size_t m_groups = 1;
10691069
size_t delta_k(4);
10701070
size_t n_wi(4);
@@ -1099,42 +1099,46 @@ sycl::event gemm_impl(sycl::queue &exec_q,
10991099
lhs_tp, rhs_tp, res_tp, workspace, local_B_block,
11001100
n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
11011101
m, lhs_indexer, rhs_indexer, res_indexer));
1102-
}
1103-
else if (k > n && k > m) {
1104-
constexpr size_t m_groups = 2;
1105-
size_t delta_k(4);
1106-
size_t n_wi(4);
1107-
size_t delta_n(4);
1108-
1109-
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1110-
local_mem_size, reserved_slm_size, delta_k,
1111-
n_wi, // modified by reference
1112-
delta_n // modified by reference
1113-
);
1114-
1115-
size_t n_blocks = (n + delta_n - 1) / delta_n;
1116-
size_t m_blocks = (m + m_groups - 1) / m_groups;
1117-
size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k);
1118-
1119-
size_t lws = delta_n * delta_k;
1120-
1121-
auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws);
1122-
auto lRange = sycl::range<1>(lws);
1123-
1124-
auto ndRange = sycl::nd_range<1>(gRange, lRange);
1125-
1126-
using LocAccT = sycl::local_accessor<sycl::vec<resTy, m_groups>, 1>;
1127-
LocAccT local_B_block(n_wi * delta_k, cgh);
1128-
LocAccT workspace(delta_n * delta_k, cgh);
1129-
1130-
using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
1131-
OuterInnerIndexerT, m_groups>;
1132-
cgh.parallel_for<KernelName>(
1133-
ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
1134-
OuterInnerIndexerT, m_groups>(
1135-
lhs_tp, rhs_tp, res_tp, workspace, local_B_block,
1136-
n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
1137-
m, lhs_indexer, rhs_indexer, res_indexer));
1102+
// }
1103+
// else if (k > n && k > m) {
1104+
// constexpr size_t m_groups = 2;
1105+
// size_t delta_k(4);
1106+
// size_t n_wi(4);
1107+
// size_t delta_n(4);
1108+
1109+
// gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1110+
// local_mem_size, reserved_slm_size, delta_k,
1111+
// n_wi, // modified by reference
1112+
// delta_n // modified by reference
1113+
// );
1114+
1115+
// size_t n_blocks = (n + delta_n - 1) / delta_n;
1116+
// size_t m_blocks = (m + m_groups - 1) / m_groups;
1117+
// size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi *
1118+
// delta_k);
1119+
1120+
// size_t lws = delta_n * delta_k;
1121+
1122+
// auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks *
1123+
// lws); auto lRange = sycl::range<1>(lws);
1124+
1125+
// auto ndRange = sycl::nd_range<1>(gRange, lRange);
1126+
1127+
// using LocAccT = sycl::local_accessor<sycl::vec<resTy,
1128+
// m_groups>, 1>; LocAccT local_B_block(n_wi * delta_k, cgh);
1129+
// LocAccT workspace(delta_n * delta_k, cgh);
1130+
1131+
// using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
1132+
// OuterInnerIndexerT,
1133+
// m_groups>;
1134+
// cgh.parallel_for<KernelName>(
1135+
// ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
1136+
// OuterInnerIndexerT,
1137+
// m_groups>(
1138+
// lhs_tp, rhs_tp, res_tp, workspace,
1139+
// local_B_block, n, n_blocks, delta_n, k,
1140+
// k_blocks, delta_k, n_wi, m, lhs_indexer,
1141+
// rhs_indexer, res_indexer));
11381142
}
11391143
else {
11401144
constexpr int wi_delta_n = 2;
@@ -1230,7 +1234,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
12301234
OuterInnerIndexerT rhs_indexer{};
12311235
OuterInnerIndexerT res_indexer{};
12321236

1233-
if (m == 1) {
1237+
if (k > n && k > m || m == 1) {
12341238
constexpr size_t m_groups = 1;
12351239
size_t delta_k(4);
12361240
size_t n_wi(4);
@@ -1266,42 +1270,44 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
12661270
n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
12671271
m, lhs_indexer, rhs_indexer, res_indexer));
12681272
}
1269-
else if (k > n && k > m) {
1270-
constexpr size_t m_groups = 2;
1271-
size_t delta_k(4);
1272-
size_t n_wi(4);
1273-
size_t delta_n(4);
1274-
1275-
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1276-
local_mem_size, reserved_slm_size, delta_k,
1277-
n_wi, // modified by reference
1278-
delta_n // modified by reference
1279-
);
1280-
1281-
size_t n_blocks = (n + delta_n - 1) / delta_n;
1282-
size_t m_blocks = (m + m_groups - 1) / m_groups;
1283-
size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k);
1284-
1285-
size_t lws = delta_n * delta_k;
1286-
1287-
auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws);
1288-
auto lRange = sycl::range<1>(lws);
1289-
1290-
auto ndRange = sycl::nd_range<1>(gRange, lRange);
1291-
1292-
using LocAccT = sycl::local_accessor<sycl::vec<resTy, m_groups>, 1>;
1293-
LocAccT local_B_block(n_wi * delta_k, cgh);
1294-
LocAccT workspace(delta_n * delta_k, cgh);
1295-
1296-
using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
1297-
OuterInnerIndexerT, m_groups>;
1298-
cgh.parallel_for<KernelName>(
1299-
ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
1300-
OuterInnerIndexerT, m_groups>(
1301-
lhs_tp, rhs_tp, res_tp, workspace, local_B_block,
1302-
n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi,
1303-
m, lhs_indexer, rhs_indexer, res_indexer));
1304-
}
1273+
// else if (k > n && k > m) {
1274+
// constexpr size_t m_groups = 2;
1275+
// size_t delta_k(4);
1276+
// size_t n_wi(4);
1277+
// size_t delta_n(4);
1278+
1279+
// gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1280+
// local_mem_size, reserved_slm_size, delta_k,
1281+
// n_wi, // modified by reference
1282+
// delta_n // modified by reference
1283+
// );
1284+
1285+
// size_t n_blocks = (n + delta_n - 1) / delta_n;
1286+
// size_t m_blocks = (m + m_groups - 1) / m_groups;
1287+
// size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k);
1288+
1289+
// size_t lws = delta_n * delta_k;
1290+
1291+
// auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks *
1292+
// lws); auto lRange = sycl::range<1>(lws);
1293+
1294+
// auto ndRange = sycl::nd_range<1>(gRange, lRange);
1295+
1296+
// using LocAccT = sycl::local_accessor<sycl::vec<resTy, m_groups>,
1297+
// 1>; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT
1298+
// workspace(delta_n * delta_k, cgh);
1299+
1300+
// using KernelName = class gemm_k_krn<lhsTy, rhsTy, resTy,
1301+
// OuterInnerIndexerT,
1302+
// m_groups>;
1303+
// cgh.parallel_for<KernelName>(
1304+
// ndRange, GemmFunctorThreadK<lhsTy, rhsTy, resTy, LocAccT,
1305+
// OuterInnerIndexerT, m_groups>(
1306+
// lhs_tp, rhs_tp, res_tp, workspace,
1307+
// local_B_block, n, n_blocks, delta_n, k,
1308+
// k_blocks, delta_k, n_wi, m, lhs_indexer,
1309+
// rhs_indexer, res_indexer));
1310+
// }
13051311
else {
13061312
constexpr int wi_delta_n = 2;
13071313
constexpr int wi_delta_m = 4;

0 commit comments

Comments
 (0)