@@ -58,7 +58,7 @@ uint16_t make_bf16(float x) {
   return (uint16_t)*res;
 }
 
-template <typename T1, typename T2, size_t Big_N, size_t Big_K>
+template <size_t Big_N, size_t Big_K, typename T1, typename T2>
 T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
   T2 res = C[m * Big_N + n];
 
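The reordering above puts the two non-deducible size parameters first, so call sites can write matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C) and let T1/T2 be deduced from the pointer arguments. A minimal sketch of the same pattern, using hypothetical names (ref_element, toy sizes) rather than anything from this patch:

#include <cstddef>

// Hypothetical stand-in for matrix_ref_mn (names and sizes are illustrative):
// the non-deducible size parameters come first, so only they need to be
// spelled out at the call site; the element types are deduced from the
// pointer arguments.
template <std::size_t Big_N, std::size_t Big_K, typename T1, typename T2>
T2 ref_element(int m, int n, T1 *A, T1 *B, T2 *C) {
  T2 res = C[m * Big_N + n];
  for (std::size_t k = 0; k < Big_K; ++k)
    res += A[m * Big_K + k] * B[k * Big_N + n];
  return res;
}

int main() {
  float A[2 * 3] = {1, 2, 3, 4, 5, 6};
  float B[3 * 2] = {1, 0, 0, 1, 1, 1};
  float C[2 * 2] = {};
  // Only the sizes are explicit; T1 and T2 are deduced as float, which is
  // what lets the updated call sites drop the explicit <T1, T2, ...>.
  return ref_element<2, 3>(0, 0, A, B, C) == 4.f ? 0 : 1;
}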
@@ -80,7 +80,8 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
 }
 
 template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
-          size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1>
+          size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
+          typename T3 = std::remove_const_t<T1>>
 void test(queue &q) {
 
   constexpr auto Big_M =
@@ -93,25 +94,26 @@ void test(queue &q) {
       Sub_Tiles_K *
       K; // total number of K dimension matrix elements for the "Big matrix".
 
-  T1 A[Big_M * Big_K];
-  T1 B[Big_K * Big_N];
-  T2 C[Big_M * Big_N];
-  T2 D[Big_M * Big_N];
+  std::remove_const_t<T1> A[Big_M * Big_K];
+  std::remove_const_t<T1> B[Big_K * Big_N];
+  std::remove_const_t<T2> C[Big_M * Big_N];
+  std::remove_const_t<T2> D[Big_M * Big_N];
 
   for (int i = 0; i < Big_M * Big_N; i++) {
     C[i] = 1;
     D[i] = 0;
   }
 
-  if constexpr (std::is_same<T1, uint16_t>::value) {
+  if constexpr (std::is_same<std::remove_const_t<T1>, uint16_t>::value) {
     for (int i = 0; i < Big_M * Big_K; i++) {
       A[i] = make_bf16(0.1f * (i % 10));
     }
 
     for (int i = 0; i < Big_K * Big_N; i++) {
       B[i] = make_bf16(0.1f * (i % 10));
     }
-  } else if constexpr (!std::is_same<T1, bfloat16>::value) {
+  } else if constexpr (!std::is_same<std::remove_const_t<T1>,
+                                     bfloat16>::value) {
     for (int i = 0; i < Big_M * Big_K; i++) {
       A[i] = i % 100;
     }
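The std::remove_const_t changes above are what let test be instantiated with const-qualified element types while the host arrays stay writable. A small standalone sketch of that idea, with a hypothetical name (fill_and_check) not taken from the patch:

#include <type_traits>

// Illustrative only: a template that may be instantiated with a
// const-qualified element type still needs mutable host storage, so the array
// element type strips const, mirroring std::remove_const_t<T1> above.
template <typename T, typename Plain = std::remove_const_t<T>>
void fill_and_check() {
  Plain data[4]; // writable even when T = const float
  for (int i = 0; i < 4; ++i)
    data[i] = static_cast<Plain>(i);
  static_assert(std::is_same<std::remove_const_t<T>, Plain>::value,
                "default argument mirrors the patch");
}

int main() {
  fill_and_check<float>();
  fill_and_check<const float>(); // a plain "T data[4]" could not be assigned to
  return 0;
}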
@@ -121,41 +123,43 @@ void test(queue &q) {
     }
   }
   {
-    buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
-    buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
-    buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
-    buffer<T2, 1> bufD(D, range<1>(Big_M * Big_N));
+    if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
 
-    // currently bfloat16 has to be initialized on device
-    if constexpr (std::is_same<T1, bfloat16>::value) {
+      buffer<bfloat16, 1> bufA(A, range<1>(Big_M * Big_K));
+      buffer<bfloat16, 1> bufB(B, range<1>(Big_K * Big_N));
 
       q.submit([&](handler &cgh) {
-        accessor<T1, 1, access::mode::read_write, target::device> accA(bufA,
-                                                                       cgh);
+        accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
+                                                                        cgh);
 
-        cgh.parallel_for<KernelName<bfloat16, class copyA, M, K, N>>(
+        cgh.parallel_for<KernelName<T1, class copyA, M, K, N>>(
             range<1>(Big_M * Big_K), [=](item<1> item) {
              auto i = item.get_linear_id();
              accA[i] = 0.1f * (i % 10);
            });
       });
-
       q.submit([&](handler &cgh) {
-        accessor<T1, 1, access::mode::read_write, target::device> accB(bufB,
-                                                                       cgh);
+        accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
+                                                                        cgh);
 
-        cgh.parallel_for<KernelName<bfloat16, class copyB, M, K, N>>(
+        cgh.parallel_for<KernelName<T1, class copyB, M, K, N>>(
             range<1>(Big_K * Big_N), [=](item<1> item) {
              auto i = item.get_linear_id();
              accB[i] = 0.1f * (i % 10);
            });
       });
     }
 
+    buffer<T1, 1> bufA(A, range<1>(Big_M * Big_K));
+    buffer<T1, 1> bufB(B, range<1>(Big_K * Big_N));
+    buffer<T2, 1> bufC(C, range<1>(Big_M * Big_N));
+    buffer<std::remove_const_t<T2>, 1> bufD(D, range<1>(Big_M * Big_N));
+
     q.submit([&](handler &cgh) {
-      accessor<T1, 1, access::mode::read_write, target::device> accA(bufA, cgh);
-      accessor<T1, 1, access::mode::read_write, target::device> accB(bufB, cgh);
-      accessor<T2, 1, access::mode::read_write, target::device> accC(bufC, cgh);
-      accessor<T2, 1, access::mode::read_write, target::device> accD(bufD, cgh);
+      accessor<T1, 1, access::mode::read, target::device> accA(bufA, cgh);
+      accessor<T1, 1, access::mode::read, target::device> accB(bufB, cgh);
+      accessor<T2, 1, access::mode::read, target::device> accC(bufC, cgh);
+      accessor<std::remove_const_t<T2>, 1, access::mode::write, target::device>
+          accD(bufD, cgh);
 
       range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP};
       range<2> GlobalRange = {Sub_Tiles_M,
@@ -177,7 +181,7 @@ void test(queue &q) {
           joint_matrix<T3, matrix_use::b, K, N, matrix_layout::row_major>
               sub_b;
 
-          joint_matrix<T2, matrix_use::accumulator, M, N,
+          joint_matrix<std::remove_const_t<T2>, matrix_use::accumulator, M, N,
                        matrix_layout::row_major>
               sub_c;
 
@@ -216,14 +220,14 @@ void test(queue &q) {
 
   for (int m = 0; m < Big_M; m++) {
     for (int n = 0; n < Big_N; n++) {
-      if constexpr (std::is_same<T1, bfloat16>::value) {
-        auto res_device = matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C);
+      if constexpr (std::is_same<std::remove_const_t<T1>, bfloat16>::value) {
+        auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
         assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
                    (D[m * Big_N + n] + res_device) <
                bf16_eps * 2);
       } else {
-        assert((D[m * Big_N + n] ==
-                matrix_ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C)));
+        assert(
+            (D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
       }
     }
   }
@@ -241,36 +245,82 @@ int main() {
     test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
     test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
 
+    test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
+         16>(Q);
+    test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16,
+         32>(Q);
+    test<const half, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16,
+         8>(Q);
+
     // A/B/Accumulator half
     test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
     test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
     test<half, half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
+
+    test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16,
+         16>(Q);
+    test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16,
+         32>(Q);
+    test<const half, const half, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16,
+         8>(Q);
   }
   if (computeCapability >= 7.2) {
     test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
     test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
     test<int8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
 
+    test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
+         16, 16>(Q);
+    test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
+         16, 32>(Q);
+    test<const int8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
+         16, 8>(Q);
+
     test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(
         Q);
     test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
     test<uint8_t, int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
+
+    test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
+         16, 16, 16>(Q);
+    test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
+         16, 32>(Q);
+    test<const uint8_t, const int32_t, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
+         32, 16, 8>(Q);
   }
   if (computeCapability >= 8.0) {
     test<double, double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 4, 8>(Q);
+    test<const double, const double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
+         4, 8>(Q);
 
     // A/B bfloat16 using storage type
     test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
     test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
     test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
 
+    test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
+         16, 16>(Q);
+    test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
+         16, 32>(Q);
+    test<const uint16_t, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
+         16, 8>(Q);
+
     test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>(Q);
     test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>(Q);
     test<bfloat16, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>(Q);
 
+    test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16,
+         16, 16>(Q);
+    test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8,
+         16, 32>(Q);
+    test<const bfloat16, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32,
+         16, 8>(Q);
+
     // A/B tf32
     test<float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16,
          precision::tf32>(Q);
+    test<const float, const float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8,
+         16, precision::tf32>(Q);
   }
   return 0;
 };
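The new const instantiations exercise the read-only path: the A/B/C buffers keep the (possibly const) element type and are bound with access::mode::read, while only D gets a write accessor. A minimal sketch of that pattern outside the test, assuming a standard SYCL setup (the <sycl/sycl.hpp> header and the kernel name are assumptions, not from the patch):

#include <sycl/sycl.hpp>
using namespace sycl;

// Minimal sketch, not part of the patch: a buffer over const data bound with a
// read accessor (like accA/accB/accC above), and a separate writable output
// buffer with a write accessor (like accD).
int main() {
  queue q;
  const float in[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4] = {};
  {
    buffer<const float, 1> bufIn(in, range<1>(4));
    buffer<float, 1> bufOut(out, range<1>(4));
    q.submit([&](handler &cgh) {
      accessor<const float, 1, access::mode::read, target::device> accIn(bufIn,
                                                                         cgh);
      accessor<float, 1, access::mode::write, target::device> accOut(bufOut,
                                                                     cgh);
      cgh.parallel_for<class copy_kernel>(range<1>(4), [=](item<1> it) {
        accOut[it.get_linear_id()] = accIn[it.get_linear_id()];
      });
    });
  } // buffer destruction waits for the kernel and copies bufOut back to out
  return out[3] == 4.f ? 0 : 1;
}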