Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 8b8ecdf

Browse files
authored
[SYCL][CUDA] Tensor Cores - Add test cases for const T (#1280)
This tests the new cases made possible by intel/llvm#6532, in which the A, B, and accumulator accessors that are loaded to registers in joint_matrix can be of type const T.
1 parent 1f72c41 commit 8b8ecdf

File tree

1 file changed

+80
-30
lines changed

1 file changed

+80
-30
lines changed

SYCL/Matrix/joint_matrix_tensorcore.cpp

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ uint16_t make_bf16(float x) {
5858
return (uint16_t)*res;
5959
}
6060

61-
template <typename T1, typename T2, size_t Big_N, size_t Big_K>
61+
template <size_t Big_N, size_t Big_K, typename T1, typename T2>
6262
T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
6363
T2 res = C[m * Big_N + n];
6464

@@ -80,7 +80,8 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
8080
}
8181

8282
template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
83-
size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1>
83+
size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
84+
typename T3 = std::remove_const_t<T1>>
8485
void test(queue &q) {
8586

8687
constexpr auto Big_M =
@@ -93,25 +94,26 @@ void test(queue &q) {
9394
Sub_Tiles_K *
9495
K; // total number of K dimension matrix elements for the "Big matrix".
9596

96-
T1 A[Big_M * Big_K];
97-
T1 B[Big_K * Big_N];
98-
T2 C[Big_M * Big_N];
99-
T2 D[Big_M * Big_N];
97+
std::remove_const_t<T1> A[Big_M * Big_K];
98+
std::remove_const_t<T1> B[Big_K * Big_N];
99+
std::remove_const_t<T2> C[Big_M * Big_N];
100+
std::remove_const_t<T2> D[Big_M * Big_N];
100101

101102
for (int i = 0; i < Big_M * Big_N; i++) {
102103
C[i] = 1;
103104
D[i] = 0;
104105
}
105106

106-
if constexpr (std::is_same<T1, uint16_t>::value) {
107+
if constexpr (std::is_same<std::remove_const_t<T1>, uint16_t>::value) {
107108
for (int i = 0; i < Big_M * Big_K; i++) {
108109
A[i] = make_bf16(0.1f * (i % 10));
109110
}
110111

111112
for (int i = 0; i < Big_K * Big_N; i++) {
112113
B[i] = make_bf16(0.1f * (i % 10));
113114
}
114-
} else if constexpr (!std::is_same<T1, bfloat16>::value) {
115+
} else if constexpr (!std::is_same<std::remove_const_t<T1>,
116+
bfloat16>::value) {
115117
for (int i = 0; i < Big_M * Big_K; i++) {
116118
A[i] = i % 100;
117119
}
@@ -121,41 +123,43 @@ void test(queue &q) {
121123
}
122124
}
123125
{
124-
buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
125-
buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
126-
buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
127-
buffer<T2, 1> bufD(D, range<1>(Big_M * Big_N));
126+
if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
128127

129-
// currently bfloat16 has to be initialized on device
130-
if constexpr (std::is_same<T1, bfloat16>::value) {
128+
buffer<bfloat16, 1> bufA(A, range<1>(Big_M * Big_K));
129+
buffer<bfloat16, 1> bufB(B, range<1>(Big_K * Big_N));
131130
q.submit([&](handler &cgh) {
132-
accessor<T1, 1, access::mode::read_write, target::device> accA(bufA,
133-
cgh);
131+
accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
132+
cgh);
134133

135-
cgh.parallel_for<KernelName<bfloat16, class copyA, M, K, N>>(
134+
cgh.parallel_for<KernelName<T1, class copyA, M, K, N>>(
136135
range<1>(Big_M * Big_K), [=](item<1> item) {
137136
auto i = item.get_linear_id();
138137
accA[i] = 0.1f * (i % 10);
139138
});
140139
});
141-
142140
q.submit([&](handler &cgh) {
143-
accessor<T1, 1, access::mode::read_write, target::device> accB(bufB,
144-
cgh);
141+
accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
142+
cgh);
145143

146-
cgh.parallel_for<KernelName<bfloat16, class copyB, M, K, N>>(
144+
cgh.parallel_for<KernelName<T1, class copyB, M, K, N>>(
147145
range<1>(Big_K * Big_N), [=](item<1> item) {
148146
auto i = item.get_linear_id();
149147
accB[i] = 0.1f * (i % 10);
150148
});
151149
});
152150
}
153151

152+
buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
153+
buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
154+
buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
155+
buffer<std::remove_const_t<T2>, 1> bufD(D, range<1>(Big_M * Big_N));
156+
154157
q.submit([&](handler &cgh) {
155-
accessor<T1, 1, access::mode::read_write, target::device> accA(bufA, cgh);
156-
accessor<T1, 1, access::mode::read_write, target::device> accB(bufB, cgh);
157-
accessor<T2, 1, access::mode::read_write, target::device> accC(bufC, cgh);
158-
accessor<T2, 1, access::mode::read_write, target::device> accD(bufD, cgh);
158+
accessor<T1, 1, access::mode::read, target::device> accA(bufA, cgh);
159+
accessor<T1, 1, access::mode::read, target::device> accB(bufB, cgh);
160+
accessor<T2, 1, access::mode::read, target::device> accC(bufC, cgh);
161+
accessor<std::remove_const_t<T2>, 1, access::mode::write, target::device>
162+
accD(bufD, cgh);
159163

160164
range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP};
161165
range<2> GlobalRange = {Sub_Tiles_M,
@@ -177,7 +181,7 @@ void test(queue &q) {
177181
joint_matrix<T3, matrix_use::b, K, N, matrix_layout::row_major>
178182
sub_b;
179183

180-
joint_matrix<T2, matrix_use::accumulator, M, N,
184+
joint_matrix<std::remove_const_t<T2>, matrix_use::accumulator, M, N,
181185
matrix_layout::row_major>
182186
sub_c;
183187

@@ -216,14 +220,14 @@ void test(queue &q) {
216220

217221
for (int m = 0; m < Big_M; m++) {
218222
for (int n = 0; n < Big_N; n++) {
219-
if constexpr (std::is_same<T1, bfloat16>::value) {
220-
auto res_device = matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C);
223+
if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
224+
auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
221225
assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
222226
(D[m * Big_N + n] + res_device) <
223227
bf16_eps * 2);
224228
} else {
225-
assert((D[m * Big_N + n] ==
226-
matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C)));
229+
assert(
230+
(D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
227231
}
228232
}
229233
}
@@ -241,36 +245,82 @@ int main() {
241245
test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
242246
test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
243247

248+
test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
249+
16>(Q);
250+
test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16,
251+
32>(Q);
252+
test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16,
253+
8>(Q);
254+
244255
// A/B/Accumulator half
245256
test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
246257
test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
247258
test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
259+
260+
test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
261+
16>(Q);
262+
test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16,
263+
32>(Q);
264+
test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16,
265+
8>(Q);
248266
}
249267
if (computeCapability >= 7.2) {
250268
test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
251269
test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
252270
test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
253271

272+
test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
273+
16, 16>(Q);
274+
test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
275+
16, 32>(Q);
276+
test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
277+
16, 8>(Q);
278+
254279
test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(
255280
Q);
256281
test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
257282
test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
283+
284+
test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
285+
16, 16, 16>(Q);
286+
test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
287+
16, 32>(Q);
288+
test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
289+
32, 16, 8>(Q);
258290
}
259291
if (computeCapability >= 8.0) {
260292
test<double, double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 4, 8>(Q);
293+
test<const double, const double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
294+
4, 8>(Q);
261295

262296
// A/B bfloat16 using storage type
263297
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
264298
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
265299
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
266300

301+
test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
302+
16, 16>(Q);
303+
test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
304+
16, 32>(Q);
305+
test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
306+
16, 8>(Q);
307+
267308
test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
268309
test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
269310
test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
270311

312+
test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
313+
16, 16>(Q);
314+
test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
315+
16, 32>(Q);
316+
test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
317+
16, 8>(Q);
318+
271319
// A/B tf32
272320
test<float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16,
273321
precision::tf32>(Q);
322+
test<const float, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8,
323+
16, precision::tf32>(Q);
274324
}
275325
return 0;
276326
};

0 commit comments

Comments
 (0)