Skip to content

Commit ad53472

Browse files
committed
Fix incorrect comments throughtout gemm kernels
Comments incorrectly stated that the third argument to `scale_gemm_k_parameters` is modified by reference
1 parent 60f1d21 commit ad53472

File tree

1 file changed

+60
-80
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+60
-80
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 60 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,10 +1071,9 @@ sycl::event gemm_impl(sycl::queue &exec_q,
10711071
size_t delta_n(16);
10721072

10731073
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1074-
local_mem_size, reserved_slm_size,
1075-
delta_k, // modified by reference
1076-
n_wi, // modified by reference
1077-
delta_n // modified by reference
1074+
local_mem_size, reserved_slm_size, delta_k,
1075+
n_wi, // modified by reference
1076+
delta_n // modified by reference
10781077
);
10791078

10801079
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -1108,10 +1107,9 @@ sycl::event gemm_impl(sycl::queue &exec_q,
11081107
size_t delta_n(16);
11091108

11101109
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1111-
local_mem_size, reserved_slm_size,
1112-
delta_k, // modified by reference
1113-
n_wi, // modified by reference
1114-
delta_n // modified by reference
1110+
local_mem_size, reserved_slm_size, delta_k,
1111+
n_wi, // modified by reference
1112+
delta_n // modified by reference
11151113
);
11161114

11171115
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -1239,10 +1237,9 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
12391237
size_t delta_n(16);
12401238

12411239
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1242-
local_mem_size, reserved_slm_size,
1243-
delta_k, // modified by reference
1244-
n_wi, // modified by reference
1245-
delta_n // modified by reference
1240+
local_mem_size, reserved_slm_size, delta_k,
1241+
n_wi, // modified by reference
1242+
delta_n // modified by reference
12461243
);
12471244

12481245
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -1276,10 +1273,9 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
12761273
size_t delta_n(16);
12771274

12781275
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1279-
local_mem_size, reserved_slm_size,
1280-
delta_k, // modified by reference
1281-
n_wi, // modified by reference
1282-
delta_n // modified by reference
1276+
local_mem_size, reserved_slm_size, delta_k,
1277+
n_wi, // modified by reference
1278+
delta_n // modified by reference
12831279
);
12841280

12851281
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -1976,10 +1972,9 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
19761972
constexpr int m_groups = 1;
19771973

19781974
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1979-
local_mem_size, reserved_slm_size,
1980-
delta_k, // modified by reference
1981-
n_wi, // modified by reference
1982-
delta_n // modified by reference
1975+
local_mem_size, reserved_slm_size, delta_k,
1976+
n_wi, // modified by reference
1977+
delta_n // modified by reference
19831978
);
19841979

19851980
sycl::event gemm_ev;
@@ -2250,10 +2245,9 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
22502245
else {
22512246
constexpr int m_groups = 2;
22522247
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
2253-
local_mem_size, reserved_slm_size,
2254-
delta_k, // modified by reference
2255-
n_wi, // modified by reference
2256-
delta_n // modified by reference
2248+
local_mem_size, reserved_slm_size, delta_k,
2249+
n_wi, // modified by reference
2250+
delta_n // modified by reference
22572251
);
22582252

22592253
sycl::event gemm_ev;
@@ -2529,10 +2523,9 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
25292523
constexpr int m_groups = 1;
25302524

25312525
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
2532-
local_mem_size, reserved_slm_size,
2533-
delta_k, // modified by reference
2534-
n_wi, // modified by reference
2535-
delta_n // modified by reference
2526+
local_mem_size, reserved_slm_size, delta_k,
2527+
n_wi, // modified by reference
2528+
delta_n // modified by reference
25362529
);
25372530

25382531
sycl::event gemm_ev;
@@ -3410,10 +3403,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
34103403
constexpr int m_groups = 1;
34113404

34123405
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
3413-
local_mem_size, reserved_slm_size,
3414-
delta_k, // modified by reference
3415-
n_wi, // modified by reference
3416-
delta_n // modified by reference
3406+
local_mem_size, reserved_slm_size, delta_k,
3407+
n_wi, // modified by reference
3408+
delta_n // modified by reference
34173409
);
34183410

34193411
sycl::event gemm_ev;
@@ -3663,10 +3655,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
36633655
else {
36643656
constexpr int m_groups = 2;
36653657
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
3666-
local_mem_size, reserved_slm_size,
3667-
delta_k, // modified by reference
3668-
n_wi, // modified by reference
3669-
delta_n // modified by reference
3658+
local_mem_size, reserved_slm_size, delta_k,
3659+
n_wi, // modified by reference
3660+
delta_n // modified by reference
36703661
);
36713662

36723663
sycl::event gemm_ev;
@@ -3920,10 +3911,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
39203911
constexpr int m_groups = 1;
39213912

39223913
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
3923-
local_mem_size, reserved_slm_size,
3924-
delta_k, // modified by reference
3925-
n_wi, // modified by reference
3926-
delta_n // modified by reference
3914+
local_mem_size, reserved_slm_size, delta_k,
3915+
n_wi, // modified by reference
3916+
delta_n // modified by reference
39273917
);
39283918

39293919
sycl::event gemm_ev;
@@ -5476,10 +5466,9 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
54765466
size_t delta_n(16);
54775467

54785468
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5479-
local_mem_size, reserved_slm_size,
5480-
delta_k, // modified by reference
5481-
n_wi, // modified by reference
5482-
delta_n // modified by reference
5469+
local_mem_size, reserved_slm_size, delta_k,
5470+
n_wi, // modified by reference
5471+
delta_n // modified by reference
54835472
);
54845473

54855474
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -5518,10 +5507,9 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
55185507
size_t delta_n(16);
55195508

55205509
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5521-
local_mem_size, reserved_slm_size,
5522-
delta_k, // modified by reference
5523-
n_wi, // modified by reference
5524-
delta_n // modified by reference
5510+
local_mem_size, reserved_slm_size, delta_k,
5511+
n_wi, // modified by reference
5512+
delta_n // modified by reference
55255513
);
55265514

55275515
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -5680,10 +5668,9 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
56805668
size_t delta_n(16);
56815669

56825670
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5683-
local_mem_size, reserved_slm_size,
5684-
delta_k, // modified by reference
5685-
n_wi, // modified by reference
5686-
delta_n // modified by reference
5671+
local_mem_size, reserved_slm_size, delta_k,
5672+
n_wi, // modified by reference
5673+
delta_n // modified by reference
56875674
);
56885675

56895676
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -5722,10 +5709,9 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
57225709
size_t delta_n(16);
57235710

57245711
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5725-
local_mem_size, reserved_slm_size,
5726-
delta_k, // modified by reference
5727-
n_wi, // modified by reference
5728-
delta_n // modified by reference
5712+
local_mem_size, reserved_slm_size, delta_k,
5713+
n_wi, // modified by reference
5714+
delta_n // modified by reference
57295715
);
57305716

57315717
size_t n_blocks = (n + delta_n - 1) / delta_n;
@@ -6506,10 +6492,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q,
65066492
if (m == 1) {
65076493
constexpr int m_groups = 1;
65086494
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
6509-
local_mem_size, reserved_slm_size,
6510-
delta_k, // modified by reference
6511-
n_wi, // modified by reference
6512-
delta_n // modified by reference
6495+
local_mem_size, reserved_slm_size, delta_k,
6496+
n_wi, // modified by reference
6497+
delta_n // modified by reference
65136498
);
65146499

65156500
if (k <= (delta_k * n_wi)) {
@@ -6836,10 +6821,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q,
68366821
constexpr int m_groups = 2;
68376822

68386823
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
6839-
local_mem_size, reserved_slm_size,
6840-
delta_k, // modified by reference
6841-
n_wi, // modified by reference
6842-
delta_n // modified by reference
6824+
local_mem_size, reserved_slm_size, delta_k,
6825+
n_wi, // modified by reference
6826+
delta_n // modified by reference
68436827
);
68446828

68456829
if (k <= (delta_k * n_wi)) {
@@ -7174,10 +7158,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q,
71747158
constexpr int m_groups = 1;
71757159

71767160
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
7177-
local_mem_size, reserved_slm_size,
7178-
delta_k, // modified by reference
7179-
n_wi, // modified by reference
7180-
delta_n // modified by reference
7161+
local_mem_size, reserved_slm_size, delta_k,
7162+
n_wi, // modified by reference
7163+
delta_n // modified by reference
71817164
);
71827165

71837166
// each group processes delta_k * n_wi
@@ -8212,10 +8195,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
82128195
if (m == 1) {
82138196
constexpr int m_groups = 1;
82148197
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
8215-
local_mem_size, reserved_slm_size,
8216-
delta_k, // modified by reference
8217-
n_wi, // modified by reference
8218-
delta_n // modified by reference
8198+
local_mem_size, reserved_slm_size, delta_k,
8199+
n_wi, // modified by reference
8200+
delta_n // modified by reference
82198201
);
82208202

82218203
if (k <= (delta_k * n_wi)) {
@@ -8533,10 +8515,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
85338515
constexpr int m_groups = 2;
85348516

85358517
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
8536-
local_mem_size, reserved_slm_size,
8537-
delta_k, // modified by reference
8538-
n_wi, // modified by reference
8539-
delta_n // modified by reference
8518+
local_mem_size, reserved_slm_size, delta_k,
8519+
n_wi, // modified by reference
8520+
delta_n // modified by reference
85408521
);
85418522

85428523
if (k <= (delta_k * n_wi)) {
@@ -8857,10 +8838,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
88578838
constexpr int m_groups = 1;
88588839

88598840
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
8860-
local_mem_size, reserved_slm_size,
8861-
delta_k, // modified by reference
8862-
n_wi, // modified by reference
8863-
delta_n // modified by reference
8841+
local_mem_size, reserved_slm_size, delta_k,
8842+
n_wi, // modified by reference
8843+
delta_n // modified by reference
88648844
);
88658845

88668846
// each group processes delta_k * n_wi

0 commit comments

Comments
 (0)