@@ -56,25 +56,25 @@ enum class matrix_type {
56
56
enum class scope_t { sub_group, work_group };
57
57
58
58
template <tpu u, typename Ta = void , typename Tb = void , typename Tc = void ,
59
- int M = 0 , int N = 0 , int K = 0 , typename Enabled = void >
59
+ int sM = 0 , int sN = 0 , int sK = 0 , typename Enabled = void >
60
60
struct tpu_params ;
61
61
62
62
#if __cplusplus >= 201703L
63
63
template <typename Ta, typename Tb, typename Tc>
64
- constexpr bool is_combination_valid_amx (int M , int N , int K ) {
64
+ constexpr bool is_combination_valid_amx (int sM , int sN , int sK ) {
65
65
// is_same_v is a C++17 feature
66
66
if ((std::is_same_v<Ta, int8_t > && std::is_same_v<Tb, int8_t > &&
67
- std::is_same_v<Tc, int > && M <= 16 && N <= 16 && K <= 64 ) ||
67
+ std::is_same_v<Tc, int > && sM <= 16 && sN <= 16 && sK <= 64 ) ||
68
68
(std::is_same_v<Ta, uint8_t > && std::is_same_v<Tb, uint8_t > &&
69
- std::is_same_v<Tc, int > && M <= 16 && N <= 16 && K <= 64 ) ||
69
+ std::is_same_v<Tc, int > && sM <= 16 && sN <= 16 && sK <= 64 ) ||
70
70
(std::is_same_v<Ta, int8_t > && std::is_same_v<Tb, uint8_t > &&
71
- std::is_same_v<Tc, int > && M <= 16 && N <= 16 && K <= 64 ) ||
71
+ std::is_same_v<Tc, int > && sM <= 16 && sN <= 16 && sK <= 64 ) ||
72
72
(std::is_same_v<Ta, uint8_t > && std::is_same_v<Tb, int8_t > &&
73
- std::is_same_v<Tc, int > && M <= 16 && N <= 16 && K <= 64 ) ||
73
+ std::is_same_v<Tc, int > && sM <= 16 && sN <= 16 && sK <= 64 ) ||
74
74
// bf16
75
75
(std::is_same_v<Ta, unsigned short > &&
76
76
std::is_same_v<Tb, unsigned short > && std::is_same_v<Tc, float > &&
77
- M <= 16 && N <= 16 && K <= 32 ))
77
+ sM <= 16 && sN <= 16 && sK <= 32 ))
78
78
return true ;
79
79
else
80
80
return false ;
@@ -100,11 +100,11 @@ constexpr bool are_types_valid_amx() {
100
100
101
101
// General query:
102
102
// types are not given, no default sizes and no implicit matrix construction
103
- template <int M , int N , int K >
104
- struct tpu_params <tpu::amx, void , void , void , M, N, K > {
105
- static constexpr std::size_t defaultM = -1 ; // depends on the type
106
- static constexpr std::size_t defaultN = -1 ;
107
- static constexpr std::size_t defaultK = -1 ;
103
+ template <int sM , int sN , int sK >
104
+ struct tpu_params <tpu::amx, void , void , void , sM , sN , sK > {
105
+ static constexpr std::size_t M = -1 ; // depends on the type
106
+ static constexpr std::size_t N = -1 ;
107
+ static constexpr std::size_t K = -1 ;
108
108
109
109
bool dynamic_p = false ; // should be true in future implementations because
110
110
// AMX hardware supports dynamic sizes
@@ -116,7 +116,7 @@ struct tpu_params<tpu::amx, void, void, void, M, N, K> {
116
116
uint32_t max_ksize;
117
117
matrix_type atype;
118
118
matrix_type btype;
119
- matrix_type ctype ;
119
+ matrix_type accumulatortype ;
120
120
uint32_t msize;
121
121
uint32_t nsize;
122
122
uint32_t ksize;
@@ -146,19 +146,17 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
146
146
" DPC++ code to implement bf16) " );
147
147
148
148
// construct the matrices using the default sizes
149
- static constexpr std::size_t defaultM = 16 ;
150
- static constexpr std::size_t defaultN = 16 ;
151
- static constexpr std::size_t defaultK = ((sizeof (Ta) == 1 ) ? 64 : 32 );
149
+ static constexpr std::size_t M = 16 ;
150
+ static constexpr std::size_t N = 16 ;
151
+ static constexpr std::size_t K = ((sizeof (Ta) == 1 ) ? 64 : 32 );
152
152
153
153
template <typename Group>
154
- using joint_matrix_a =
155
- joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
154
+ using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
156
155
template <typename Group>
157
- using joint_matrix_b =
158
- joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
156
+ using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
159
157
template <typename Group>
160
- using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
161
- layout::row_major , Group>;
158
+ using joint_matrix_accumulator =
159
+ joint_matrix<Tc, M, N, use::accumulator, layout::unused , Group>;
162
160
163
161
bool dynamic_p = false ; // should be true in future implementations because
164
162
// AMX hardware supports dynamic sizes
@@ -170,7 +168,7 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
170
168
uint32_t max_ksize;
171
169
matrix_type atype;
172
170
matrix_type btype;
173
- matrix_type ctype ;
171
+ matrix_type accumulatortype ;
174
172
uint32_t msize;
175
173
uint32_t nsize;
176
174
uint32_t ksize;
@@ -183,36 +181,34 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
183
181
184
182
// Valid or not:
185
183
// Specialization when both types and sizes are given
186
- template <typename Ta, typename Tb, typename Tc, int M , int N , int K >
184
+ template <typename Ta, typename Tb, typename Tc, int sM , int sN , int sK >
187
185
struct tpu_params <
188
- tpu::amx, Ta, Tb, Tc, M, N, K ,
186
+ tpu::amx, Ta, Tb, Tc, sM , sN , sK ,
189
187
typename std::enable_if<(
190
188
!std::is_same_v<Ta, void > && !std::is_same_v<Tb, void > &&
191
- !std::is_same_v<Tc, void > && M != 0 && N != 0 && K != 0 )>::type> {
189
+ !std::is_same_v<Tc, void > && sM != 0 && sN != 0 && sK != 0 )>::type> {
192
190
// Validate that parameters are supported
193
191
static_assert (
194
- (M == 0 && N == 0 && K == 0 ) ||
195
- (is_combination_valid_amx<Ta, Tb, Tc>(M, N, K )),
192
+ (sM == 0 && sN == 0 && sK == 0 ) ||
193
+ (is_combination_valid_amx<Ta, Tb, Tc>(sM , sN , sK )),
196
194
" Invalid parameters for AMX, query valid types and maximum sizes "
197
195
" using: tpu_params<tpu::amx> myparams; and then check out "
198
196
" myparams.combinations array" );
199
197
200
198
// if combination is valid, construct the matrices
201
199
202
- static constexpr std::size_t defaultM = (M != 0 ) ? M : 16 ;
203
- static constexpr std::size_t defaultN = (N != 0 ) ? N : 16 ;
204
- static constexpr std::size_t defaultK =
205
- (K != 0 ) ? K : ((sizeof (Ta) == 1 ) ? 64 : 32 );
200
+ static constexpr std::size_t M = (sM != 0 ) ? sM : 16 ;
201
+ static constexpr std::size_t N = (sN != 0 ) ? sN : 16 ;
202
+ static constexpr std::size_t K =
203
+ (sK != 0 ) ? sK : ((sizeof (Ta) == 1 ) ? 64 : 32 );
206
204
207
205
template <typename Group>
208
- using joint_matrix_a =
209
- joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
206
+ using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
210
207
template <typename Group>
211
- using joint_matrix_b =
212
- joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
208
+ using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
213
209
template <typename Group>
214
- using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
215
- layout::row_major , Group>;
210
+ using joint_matrix_accumulator =
211
+ joint_matrix<Tc, M, N, use::accumulator, layout::unused , Group>;
216
212
217
213
bool dynamic_p = false ; // should be true in future implementations
218
214
// because AMX hardware supports dynamic sizes
@@ -226,25 +222,25 @@ struct tpu_params<
226
222
// capabilities of the DPAS hardware.
227
223
228
224
template <typename Ta, typename Tb, typename Tc>
229
- constexpr bool is_combination_valid_dpas (int M , int N , int K ) {
225
+ constexpr bool is_combination_valid_dpas (int sM , int sN , int sK ) {
230
226
if ((std::is_same_v<Ta, int8_t > && std::is_same_v<Tb, int8_t > &&
231
- std::is_same_v<Tc, int > && (M == 1 || M == 2 || M == 4 || M == 8 ) &&
232
- N == 8 && K == 32 ) ||
227
+ std::is_same_v<Tc, int > && (sM == 1 || sM == 2 || sM == 4 || sM == 8 ) &&
228
+ sN == 8 && sK == 32 ) ||
233
229
(std::is_same_v<Ta, int8_t > && std::is_same_v<Tb, uint8_t > &&
234
- std::is_same_v<Tc, int > && (M == 1 || M == 2 || M == 4 || M == 8 ) &&
235
- N == 8 && K == 32 ) ||
230
+ std::is_same_v<Tc, int > && (sM == 1 || sM == 2 || sM == 4 || sM == 8 ) &&
231
+ sN == 8 && sK == 32 ) ||
236
232
(std::is_same_v<Ta, uint8_t > && std::is_same_v<Tb, int8_t > &&
237
- std::is_same_v<Tc, int > && (M == 1 || M == 2 || M == 4 || M == 8 ) &&
238
- N == 8 && K == 32 ) ||
233
+ std::is_same_v<Tc, int > && (sM == 1 || sM == 2 || sM == 4 || sM == 8 ) &&
234
+ sN == 8 && sK == 32 ) ||
239
235
(std::is_same_v<Ta, uint8_t > && std::is_same_v<Tb, uint8_t > &&
240
- std::is_same_v<Tc, int > && (M == 1 || M == 2 || M == 4 || M == 8 ) &&
241
- N == 8 && K == 32 ) ||
236
+ std::is_same_v<Tc, int > && (sM == 1 || sM == 2 || sM == 4 || sM == 8 ) &&
237
+ sN == 8 && sK == 32 ) ||
242
238
(std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
243
- std::is_same_v<Tc, float > && (M == 1 || M == 2 || M == 4 || M == 8 ) &&
244
- N == 8 && K == 16 ) ||
239
+ std::is_same_v<Tc, float > &&
240
+ ( sM == 1 || sM == 2 || sM == 4 || sM == 8 ) && sN == 8 && sK == 16 ) ||
245
241
(std::is_same_v<Ta, unsigned short > &&
246
242
std::is_same_v<Tb, unsigned short > && std::is_same_v<Tc, float > &&
247
- (M == 1 || M == 2 || M == 4 || M == 8 ) && N == 8 && K == 16 ))
243
+ (sM == 1 || sM == 2 || sM == 4 || sM == 8 ) && sN == 8 && sK == 16 ))
248
244
return true ;
249
245
else
250
246
return false ;
@@ -272,11 +268,11 @@ constexpr bool are_types_valid_dpas() {
272
268
273
269
// General Query
274
270
// specialization for when types are not given --> no default values
275
- template <int M , int N , int K >
276
- struct tpu_params <tpu::dpas, void , void , void , M, N, K > {
277
- static constexpr std::size_t defaultM = -1 ; // depends on the type
278
- static constexpr std::size_t defaultN = -1 ;
279
- static constexpr std::size_t defaultK = -1 ;
271
+ template <int sM , int sN , int sK >
272
+ struct tpu_params <tpu::dpas, void , void , void , sM , sN , sK > {
273
+ static constexpr std::size_t M = -1 ; // depends on the type
274
+ static constexpr std::size_t N = -1 ;
275
+ static constexpr std::size_t K = -1 ;
280
276
281
277
bool dynamic_p = false ; // no dynamic allocation on the GPU
282
278
uint32_t numtiles = -1 ; // does not apply for DPAS
@@ -288,7 +284,7 @@ struct tpu_params<tpu::dpas, void, void, void, M, N, K> {
288
284
uint32_t max_ksize;
289
285
matrix_type atype;
290
286
matrix_type btype;
291
- matrix_type ctype ;
287
+ matrix_type accumulatortype ;
292
288
uint32_t msize;
293
289
uint32_t nsize;
294
290
uint32_t ksize;
@@ -340,19 +336,17 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
340
336
341
337
// construct the matrices using the default sizes
342
338
343
- static constexpr std::size_t defaultM = 8 ;
344
- static constexpr std::size_t defaultN = 8 ;
345
- static constexpr std::size_t defaultK = ((sizeof (Ta) == 1 ) ? 32 : 16 );
339
+ static constexpr std::size_t M = 8 ;
340
+ static constexpr std::size_t N = 8 ;
341
+ static constexpr std::size_t K = ((sizeof (Ta) == 1 ) ? 32 : 16 );
346
342
347
343
template <typename Group>
348
- using joint_matrix_a =
349
- joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
344
+ using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
350
345
template <typename Group>
351
- using joint_matrix_b =
352
- joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
346
+ using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
353
347
template <typename Group>
354
- using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
355
- layout::row_major , Group>;
348
+ using joint_matrix_accumulator =
349
+ joint_matrix<Tc, M, N, use::accumulator, layout::unused , Group>;
356
350
357
351
bool dynamic_p = false ; // no dynamic allocation on the GPU
358
352
uint32_t numtiles = -1 ; // does not apply for DPAS
@@ -363,15 +357,16 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
363
357
uint32_t max_ksize;
364
358
matrix_type atype;
365
359
matrix_type btype;
366
- matrix_type ctype ;
360
+ matrix_type accumulatortype ;
367
361
uint32_t msize;
368
362
uint32_t nsize;
369
363
uint32_t ksize;
370
364
};
371
365
using mt = matrix_type;
372
366
static constexpr combination combinations[] = {
373
367
// The types used in the initialization below are fake and not used. In
374
- // this case, users already chose the types, they are only looking for the
368
+ // this case, users already chose the types, they are only looking for
369
+ // the
375
370
// sizes
376
371
{0 , 0 , 0 , mt::bf8, mt::bf8, mt::bf8, 1 , 8 , (sizeof (Ta) == 1 ) ? 32 : 16 },
377
372
{0 , 0 , 0 , mt::bf8, mt::bf8, mt::bf8, 2 , 8 , (sizeof (Ta) == 1 ) ? 32 : 16 },
@@ -384,32 +379,30 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
384
379
385
380
// Valid or not:
386
381
// Specialization when both types and sizes are given
387
- template <typename Ta, typename Tb, typename Tc, int M , int N , int K >
382
+ template <typename Ta, typename Tb, typename Tc, int sM , int sN , int sK >
388
383
struct tpu_params <
389
- tpu::dpas, Ta, Tb, Tc, M, N, K ,
390
- typename std::enable_if<((!std::is_same_v<Ta, void > && M != 0 ))>::type> {
384
+ tpu::dpas, Ta, Tb, Tc, sM , sN , sK ,
385
+ typename std::enable_if<((!std::is_same_v<Ta, void > && sM != 0 ))>::type> {
391
386
// Validate that parameters are supported
392
- static_assert ((M == 0 && N == 0 && K == 0 ) ||
393
- (is_combination_valid_dpas<Ta, Tb, Tc>(M, N, K )),
387
+ static_assert ((sM == 0 && sN == 0 && sK == 0 ) ||
388
+ (is_combination_valid_dpas<Ta, Tb, Tc>(sM , sN , sK )),
394
389
" Invalid parameters for DPAS, query valid combinations "
395
390
" using: tpu_params<tpu::dpas> myparams; and then check out "
396
391
" myparams.combinations array" );
397
392
398
393
// if combination is valid, construct the matrices
399
- static constexpr std::size_t defaultM = (M != 0 ) ? M : 8 ;
400
- static constexpr std::size_t defaultN = (N != 0 ) ? N : 8 ;
401
- static constexpr std::size_t defaultK =
402
- (K != 0 ) ? K : ((sizeof (Ta) == 1 ) ? 32 : 16 );
394
+ static constexpr std::size_t M = (sM != 0 ) ? sM : 8 ;
395
+ static constexpr std::size_t N = (sN != 0 ) ? sN : 8 ;
396
+ static constexpr std::size_t K =
397
+ (sK != 0 ) ? sK : ((sizeof (Ta) == 1 ) ? 32 : 16 );
403
398
404
399
template <typename Group>
405
- using joint_matrix_a =
406
- joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
400
+ using joint_matrix_a = joint_matrix<Ta, M, K, use::a, layout::unused, Group>;
407
401
template <typename Group>
408
- using joint_matrix_b =
409
- joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
402
+ using joint_matrix_b = joint_matrix<Tb, K, N, use::b, layout::unused, Group>;
410
403
template <typename Group>
411
- using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
412
- layout::row_major , Group>;
404
+ using joint_matrix_accumulator =
405
+ joint_matrix<Tc, M, N, use::accumulator, layout::unused , Group>;
413
406
414
407
bool dynamic_p = false ; // no dynamic allocation on the GPU
415
408
uint32_t numtiles = -1 ; // does not apply for DPAS
0 commit comments