
Commit f9e4f10

[SYCL][CUDA] Improved joint_matrix layout test coverage. (#12483)
Improved `joint_matrix` layout test coverage. The test framework used by the CUDA backend tests has been updated to support all possible `joint_matrix` GEMM API combinations, including all matrix layouts. The gemm header is backend agnostic, so all backends could use this test framework in the future. This test framework can also serve as an example of how to handle different layout combinations when computing a general GEMM; a standalone sketch of the indexing rule follows below.

Signed-off-by: JackAKirk <[email protected]>
1 parent e402523 commit f9e4f10
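For readers using this commit as the example the message describes, here is a minimal standalone sketch of the indexing rule that underpins the layout handling (plain C++ with illustrative names, not code from the patch): element (row, col) of a rows x cols matrix lives at row * cols + col when stored row-major and at row + col * rows when stored column-major; a reference GEMM simply picks the rule per operand.

```cpp
#include <cstddef>
#include <vector>

// Linear offset of element (row, col) in a rows x cols matrix.
constexpr std::size_t linear_index(bool row_major, std::size_t row,
                                   std::size_t col, std::size_t rows,
                                   std::size_t cols) {
  return row_major ? row * cols + col : row + col * rows;
}

// Reference D = A * B + C for any layout combination of A and B.
// C and D are kept row-major here for brevity.
template <typename T>
void ref_gemm(bool a_row_major, bool b_row_major, std::size_t M, std::size_t K,
              std::size_t N, const std::vector<T> &A, const std::vector<T> &B,
              const std::vector<T> &C, std::vector<T> &D) {
  for (std::size_t m = 0; m < M; m++)
    for (std::size_t n = 0; n < N; n++) {
      T acc = C[m * N + n];
      for (std::size_t k = 0; k < K; k++)
        acc += A[linear_index(a_row_major, m, k, M, K)] *
               B[linear_index(b_row_major, k, n, K, N)];
      D[m * N + n] = acc;
    }
}
```

The `matrix_ref_mn` change in the first file below performs exactly this per-operand selection, with `Big_M`/`Big_K`/`Big_N` as the big-matrix dimensions.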

File tree

4 files changed: +118 -43 lines changed

sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp

Lines changed: 68 additions & 37 deletions
@@ -11,7 +11,7 @@ using namespace sycl::ext::oneapi;
 using namespace sycl::ext::oneapi::experimental::matrix;
 constexpr float bf16_eps = 0.00390625;
 
-// Example usage of Nvidia matrix multiply.
+// Example usage of joint_matrix matrix multiply.
 // Optimizations such as memory paddings for avoiding bank conflicts are not
 // included in this test which aids clarity for what is going on. This example
 // forms a "Big matrix" corresponding to a single "TILE" using cuda example
@@ -30,37 +30,47 @@ constexpr float bf16_eps = 0.00390625;
 constexpr int N_THREADS_PER_MATRIX_OP = 32;
 
 // number of submatrices per row of accumulator ("C", "D") matrices.
-constexpr int SUB_TILES_M = 3;
+constexpr int SUB_TILES_M = 2;
 // number of submatrices per col of accumulator matrices.
 constexpr int SUB_TILES_N = 2;
 // number of submatrices per col of "A"/per row of "B", matrices.
-constexpr int SUB_TILES_K = 1;
+constexpr int SUB_TILES_K = 2;
 
-template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N>
+template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N,
+          layout layout_A, layout layout_B, layout layout_C>
 class TypeHelper;
 
-template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N>
-using KernelName = class TypeHelper<Tm, Tc, Td, M, K, N>;
+template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N,
+          layout layout_A, layout layout_B, layout layout_C>
+using KernelName =
+    class TypeHelper<Tm, Tc, Td, M, K, N, layout_A, layout_B, layout_C>;
 
-template <size_t Big_N, size_t Big_K, typename Tm, typename Tc>
+template <size_t Big_N, size_t Big_K, size_t Big_M, layout layout_A,
+          layout layout_B, typename Tm, typename Tc>
 Tc matrix_ref_mn(const int &m, const int &n, Tm *A, Tm *B, Tc *C) {
   Tc res = C[m * Big_N + n];
 
-  if constexpr (std::is_same<Tm, bfloat16>::value) {
-    for (int k = 0; k < Big_K; k++)
-      res += A[m * Big_K + k] * B[k * Big_N + n];
-  } else {
-    for (int k = 0; k < Big_K; k++)
-      res +=
-          static_cast<Tc>(A[m * Big_K + k]) * static_cast<Tc>(B[k * Big_N + n]);
+  for (int k = 0; k < Big_K; k++) {
+    auto index_a =
+        layout_A == layout::row_major ? m * Big_K + k : m + k * Big_M;
+    auto index_b =
+        layout_B == layout::row_major ? k * Big_N + n : k + n * Big_K;
+
+    if constexpr (std::is_same<Tm, bfloat16>::value) {
+      res += A[index_a] * B[index_b];
+    } else {
+      res += static_cast<Tc>(A[index_a]) * static_cast<Tc>(B[index_b]);
+    }
   }
 
   return res;
 }
 
-template <typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
-          size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
-          typename T3 = std::remove_const_t<Tm>>
+template <
+    typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
+    size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
+    layout layout_A = layout::row_major, layout layout_B = layout::row_major,
+    layout layout_C = layout::row_major, typename T3 = std::remove_const_t<Tm>>
 void test(queue &q) {
   // total number of M dimension matrix elements for the "Big matrix".
   constexpr auto Big_M = Sub_Tiles_M * M;
@@ -97,7 +107,8 @@ void test(queue &q) {
       accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
                                                                       cgh);
 
-      cgh.parallel_for<KernelName<Tm, Tc, class copyA, M, K, N>>(
+      cgh.parallel_for<KernelName<Tm, Tc, class copyA, M, K, N, layout_A,
+                                  layout_B, layout_C>>(
           range<1>(Big_M * Big_K), [=](item<1> item) {
             auto i = item.get_linear_id();
             accA[i] = 0.1f * (i % 10);
@@ -107,7 +118,8 @@ void test(queue &q) {
       accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
                                                                       cgh);
 
-      cgh.parallel_for<KernelName<Tm, Tc, class copyB, M, K, N>>(
+      cgh.parallel_for<KernelName<Tm, Tc, class copyB, M, K, N, layout_A,
+                                  layout_B, layout_C>>(
          range<1>(Big_K * Big_N), [=](item<1> item) {
            auto i = item.get_linear_id();
            accB[i] = 0.1f * (i % 10);
@@ -130,41 +142,55 @@ void test(queue &q) {
       range<2> GlobalRange = {Sub_Tiles_M,
                               Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};
 
-      cgh.parallel_for<KernelName<Tm, Tc, Td, M, K, N>>(
+      cgh.parallel_for<
+          KernelName<Tm, Tc, Td, M, K, N, layout_A, layout_B, layout_C>>(
           nd_range<2>(GlobalRange, LocalRange), [=](nd_item<2> item) {
             sycl::sub_group sg = item.get_sub_group();
             // row id of current submatrix of BIG C matrix
             const auto m = item.get_group().get_group_id()[0];
             // column id of current submatrix of BIG C matrix
             const auto n = item.get_group().get_group_id()[1];
 
-            joint_matrix<sycl::sub_group, T3, use::a, M, K, layout::row_major>
-                sub_a;
-            joint_matrix<sycl::sub_group, T3, use::b, K, N, layout::row_major>
-                sub_b;
+            joint_matrix<sycl::sub_group, T3, use::a, M, K, layout_A> sub_a;
+            joint_matrix<sycl::sub_group, T3, use::b, K, N, layout_B> sub_b;
             joint_matrix<sycl::sub_group, std::remove_const_t<Tc>,
                          use::accumulator, M, N>
                 sub_c;
             joint_matrix<sycl::sub_group, Td, use::accumulator, M, N> sub_d;
+            auto stride_C = layout_C == layout::row_major ? Big_N : Big_M;
+            auto load_stride_C = layout_C == layout::row_major
+                                     ? (m * M) * Big_N + n * N
+                                     : (m * M) + n * N * Big_M;
 
             joint_matrix_load(
                 sg, sub_c,
                 accC.template get_multi_ptr<access::decorated::no>() +
-                    (m * M) * Big_N + n * N,
-                Big_N, layout::row_major);
+                    load_stride_C,
+                stride_C, layout_C);
+
+            auto stride_A = layout_A == layout::row_major ? Big_K : Big_M;
+            auto stride_B = layout_B == layout::row_major ? Big_N : Big_K;
+
             // k = row/col id of current submatrix of BIG A/B matrices
             for (int k = 0; k < Sub_Tiles_K; k++) {
+              auto load_stride_A = layout_A == layout::row_major
+                                       ? (k * K) + (m * M * Big_K)
+                                       : (k * K * Big_M) + (m * M);
+              auto load_stride_B = layout_B == layout::row_major
+                                       ? (k * K * Big_N) + (n * N)
+                                       : (k * K) + (n * N * Big_K);
+
               joint_matrix_load(
                   sg, sub_a,
                   accA.template get_multi_ptr<access::decorated::no>() +
-                      (k * K) + (m * M * Big_K),
-                  Big_K);
+                      load_stride_A,
+                  stride_A);
 
               joint_matrix_load(
                   sg, sub_b,
                   accB.template get_multi_ptr<access::decorated::no>() +
-                      (k * K * Big_N) + (n * N),
-                  Big_N);
+                      load_stride_B,
+                  stride_B);
 
               // round values to correct precision if using tf32
               if constexpr (std::is_same<T3, precision::tf32>::value) {
@@ -174,27 +200,32 @@ void test(queue &q) {
               }
 
               joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
+              joint_matrix_copy(sg, sub_d, sub_c);
             }
             joint_matrix_store(
                 sg, sub_d,
                 accD.template get_multi_ptr<access::decorated::no>() +
-                    (m * M) * Big_N + n * N,
-                Big_N, layout::row_major);
+                    load_stride_C,
+                stride_C, layout_C);
           });
     });
     q.wait();
   }
 
   for (int m = 0; m < Big_M; m++) {
     for (int n = 0; n < Big_N; n++) {
+      auto index_D =
+          layout_C == layout::row_major ? m * Big_N + n : m + n * Big_M;
       if constexpr (std::is_same<std::remove_const_t<Tm>, bfloat16>::value) {
-        auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
-        assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
-                   (D[m * Big_N + n] + res_device) <
+        auto res_device =
+            matrix_ref_mn<Big_N, Big_K, Big_M, layout_A, layout_B>(m, n, A, B,
+                                                                   C);
+        assert(fabs(2 * (D[index_D] - res_device)) / (D[index_D] + res_device) <
               bf16_eps * 2);
      } else {
-        assert(
-            (D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
+        assert((D[index_D] ==
+                matrix_ref_mn<Big_N, Big_K, Big_M, layout_A, layout_B>(m, n, A,
+                                                                       B, C)));
      }
    }
  }
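The sub-tile loads above pair a layout-dependent base offset with a layout-dependent leading dimension ("stride"). A small self-contained check of that arithmetic, mirroring `load_stride_A`/`stride_A`; the names and the 32x32/16x16 sizes are illustrative, not part of the patch:

```cpp
#include <cstddef>

struct TilePos {
  std::size_t offset; // distance from the matrix base to the sub-tile origin
  std::size_t ld;     // leading dimension to pass alongside the load
};

// Sub-tile (tile_row, tile_col) of tile_rows x tile_cols elements inside a
// big_rows x big_cols matrix, for either storage layout.
constexpr TilePos sub_tile(bool row_major, std::size_t tile_row,
                           std::size_t tile_col, std::size_t tile_rows,
                           std::size_t tile_cols, std::size_t big_rows,
                           std::size_t big_cols) {
  return row_major
             ? TilePos{tile_row * tile_rows * big_cols + tile_col * tile_cols,
                       big_cols}
             : TilePos{tile_row * tile_rows + tile_col * tile_cols * big_rows,
                       big_rows};
}

// A 32x32 "Big" matrix tiled into 16x16 sub-tiles: sub-tile (1, 1) starts
// 528 elements in under both layouts here, though the two formulas differ
// in general.
static_assert(sub_tile(true, 1, 1, 16, 16, 32, 32).offset == 16 * 32 + 16);
static_assert(sub_tile(false, 1, 1, 16, 16, 32, 32).offset == 16 + 16 * 32);
```

In the diff, the leading dimension for a row-major operand is the big matrix's column count (`Big_K` for A, `Big_N` for B), and for a column-major operand its row count (`Big_M` for A, `Big_K` for B).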

sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp

Lines changed: 13 additions & 2 deletions
@@ -80,12 +80,23 @@ int main() {
   test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
        32, 16, 8>(Q);
 
+  // test different layout combinations for one case
+
+  test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
+       32, 16, 8, layout::row_major, layout::row_major, layout::col_major>(Q);
+  test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
+       32, 16, 8, layout::row_major, layout::col_major, layout::row_major>(Q);
+  test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
+       32, 16, 8, layout::col_major, layout::row_major, layout::row_major>(Q);
+  test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
+       32, 16, 8, layout::col_major, layout::col_major, layout::row_major>(Q);
+
+  // joint_matrix_apply tests
+
   auto apply_add = [](auto &x) { x = x + 2; };
   float D[MATRIX_M][MATRIX_N];
   big_matrix<float, MATRIX_M, MATRIX_N> MD_f((float *)&D);
 
-  // joint_matrix_apply tests
-
   matrix_verify_lambda<half, float, M, 16, N>(Q, MD_f, 0.0, apply_add);
 }

sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp

Lines changed: 17 additions & 2 deletions
@@ -50,13 +50,28 @@ int main() {
   test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
        SUB_TILES_N, 32, 16, 8>(Q);
 
+  // test different layout combinations for one case
+
+  test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 32, 16, 8, layout::row_major, layout::row_major,
+       layout::col_major>(Q);
+  test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 32, 16, 8, layout::col_major, layout::row_major,
+       layout::row_major>(Q);
+  test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 32, 16, 8, layout::row_major, layout::col_major,
+       layout::row_major>(Q);
+  test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 32, 16, 8, layout::col_major, layout::col_major,
+       layout::row_major>(Q);
+
+  // joint_matrix_apply tests
+
   auto apply_add = [](auto &x) { x = x + 2; };
 
   int32_t D_i[MATRIX_M][MATRIX_N];
   big_matrix<int32_t, MATRIX_M, MATRIX_N> MD_i((int32_t *)&D_i);
 
-  // joint_matrix_apply tests
-
   matrix_verify_lambda<uint8_t, int32_t, M, 16, N>(Q, MD_i, 0, apply_add);
   matrix_verify_lambda<int8_t, int32_t, M, 16, N>(Q, MD_i, 0, apply_add);
 }

sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp

Lines changed: 20 additions & 2 deletions
@@ -43,9 +43,28 @@ int main() {
 
   // A/B tf32
   test<float, float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16,
+       layout::row_major, layout::row_major, layout::row_major,
        precision::tf32>(Q);
   test<const float, const float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
-       16, 8, 16, precision::tf32>(Q);
+       16, 8, 16, layout::row_major, layout::row_major, layout::row_major,
+       precision::tf32>(Q);
+
+  // test different layout combinations for one case
+
+  test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 8, 16, 32, layout::row_major, layout::col_major,
+       layout::row_major>(Q);
+  test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 8, 16, 32, layout::col_major, layout::row_major,
+       layout::row_major>(Q);
+  test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 8, 16, 32, layout::col_major, layout::col_major,
+       layout::row_major>(Q);
+  test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
+       SUB_TILES_N, 8, 16, 32, layout::col_major, layout::col_major,
+       layout::col_major>(Q);
+
+  // joint_matrix_apply tests
 
   float D[MATRIX_M][MATRIX_N];
   big_matrix<float, MATRIX_M, MATRIX_N> MD_f((float *)&D);
@@ -54,7 +73,6 @@ int main() {
   big_matrix<double, 8 * nWGperDim, 8 * nWGperDim> MD_d((double *)&D_d);
   auto apply_add = [](auto &x) { x = x + 2; };
 
-  // joint_matrix_apply tests
   matrix_verify_lambda<bfloat16, float, 16, 16, 16>(Q, MD_f, 0.0, apply_add);
 
   matrix_verify_lambda<double, double, 8, 4, 8>(Q, MD_d, -60.0, apply_add);
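One consequence visible in the tf32 hunk above: because the three defaulted layout parameters now precede `T3` in `test`'s template parameter list, any call that overrides `T3` must also spell out the layouts, since C++ cannot skip earlier defaulted template arguments. A minimal illustration of the language rule, with hypothetical names:

```cpp
template <int A = 1, int B = 2, typename T = float> struct S {};

S<> all_defaults;        // fine: A = 1, B = 2, T = float
S<1, 2, int> override_t; // to reach T, A and B must be restated first
```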

0 commit comments