 //
 // ===--------------------------------------------------------------------=== //
 // This file implements the static query interface for the joint_matrix
-// experimental extension. AMX, DPAS and different other TPUs support different
-// logical sizes and types. The query interface is used to validate user code
-// and inform them about supported types, sizes, scope, and layouts by the
-// current implementation. Note that this query interface is a compile-time
-// query, so there will be no runtime errors. The query interface provides
-// three functionalities:
-// 1- At compile time, inform the user whether a specific
-// combination is valid or not.
-// 2- Construct the matrices using a default shape
-// if user does not provide a combination
-// 3- General query interface for sizes, types,
-// static/dynamic, scope. This is needed to void padding by the user,
-// for tuning, and efficient code generation if used by a library.
+// experimental extension. Intel AMX, Intel XMX, and Nvidia Tensor Cores support
+// different logical sizes and types. The query interface is used to validate
+// user code and inform them about supported types, sizes, scopes, and layouts
+// by the current implementation. Note that this query interface is a
+// compile-time query, so there will be no runtime errors. The query interface
+// provides three functionalities: 1- At compile time, inform the user whether a
+// specific combination is valid or not. 2- Construct the matrices using a
+// default shape if user does not provide a combination 3- General query
+// interface for sizes, types, scopes. This is needed to void padding by the
+// user, for tuning, and efficient code generation if used by a library.
 
 #pragma once
 
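The header comment above describes the three query functionalities. As a quick illustration (not part of this commit), user code exercising them could read roughly as below; the namespace path, the AMX shape used for the compile-time check, and the presence of num_combinations on the AMX general query are assumptions here, and the primary tpu_params template with its defaulted type/size parameters is defined elsewhere in this header.

#include <cstdint>
#include <sycl/sycl.hpp>

void query_interface_sketch() {
  using namespace sycl::ext::oneapi::experimental::matrix; // assumed path

  // 1- Compile-time validation: instantiating tpu_params with a complete
  //    type/size combination triggers its static_assert if unsupported.
  using checked = tpu_params<tpu::amx, int8_t, int8_t, int, 16, 16, 64>;
  (void)sizeof(checked); // force instantiation

  // 2- Default shapes: give only the types and read back the chosen M, N, K.
  using defaults = tpu_params<tpu::amx, int8_t, int8_t, int>;
  constexpr std::size_t m = defaults::M, n = defaults::N, k = defaults::K;
  static_assert(m > 0 && n > 0 && k > 0, "default shapes are positive");

  // 3- General query: give only the TPU and walk the supported combinations.
  using general = tpu_params<tpu::amx>;
  for (int i = 0; i < general::num_combinations; ++i) {
    // general::combinations[i] holds atype/btype/accumulatortype and the
    // msize/nsize/ksize (or max_*size) values for one supported shape.
  }
}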
@@ -29,14 +26,15 @@ namespace oneapi {
 namespace experimental::matrix {
 
 enum class tpu {
-  dpas,
+  xmx8,
+  xmx16,
   amx,
 };
 enum class matrix_type {
   bf8,
   bf16,
   fp16,
-  fp19, // tfloat32
+  tf32,
   fp32,
   fp64,
   sint2,
@@ -104,10 +102,9 @@ struct tpu_params<tpu::amx, void, void, void, sM, sN, sK> {
   static constexpr std::size_t N = -1;
   static constexpr std::size_t K = -1;
 
-  bool dynamic_p = false; // should be true in future implementations because
-                          // AMX hardware supports dynamic sizes
   uint32_t numtiles = 8;
-  scope_t scope = scope_t::sub_group;
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
   struct combination {
     uint32_t max_msize;
     uint32_t max_nsize;
@@ -155,10 +152,9 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
   using joint_matrix_accumulator =
       joint_matrix<Group, Tc, use::accumulator, M, N>;
 
-  bool dynamic_p = false; // should be true in future implementations because
-                          // AMX hardware supports dynamic sizes
   uint32_t numtiles = 8;
-  scope_t scope = scope_t::sub_group;
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
   struct combination {
     uint32_t max_msize;
     uint32_t max_nsize;
@@ -207,19 +203,18 @@ struct tpu_params<
   using joint_matrix_accumulator =
       joint_matrix<Group, Tc, use::accumulator, M, N>;
 
-  bool dynamic_p = false; // should be true in future implementations
-                          // because AMX hardware supports dynamic sizes
   uint32_t numtiles = 8;
-  scope_t scope = scope_t::sub_group;
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
 };
 
-// DPAS case
-// The DPAS implementation supports the logical capability support of the HW
-// So in this case, M, N, K sizes returned by the query represent the logical
-// capabilities of the DPAS hardware.
+// Intel XMX with SIMD8 capability
+// The Intel XMX implementation supports the logical capability support of the
+// HW So in this case, M, N, K sizes returned by the query represent the logical
+// capabilities of the Intel XMX hardware.
 
 template <typename Ta, typename Tb, typename Tc>
-constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) {
+constexpr bool is_combination_valid_xmx8(int sM, int sN, int sK) {
   if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
        std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
        sN == 8 && sK == 32) ||
@@ -244,7 +239,7 @@ constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) {
 }
 
 template <typename Ta, typename Tb, typename Tc>
-constexpr bool are_types_valid_dpas() {
+constexpr bool are_types_valid_xmx8() {
   if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
        std::is_same_v<Tc, int>) ||
       (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
@@ -265,14 +260,14 @@ constexpr bool are_types_valid_dpas() {
 // General Query
 // specialization for when types are not given --> no default values
 template <int sM, int sN, int sK>
-struct tpu_params<tpu::dpas, void, void, void, sM, sN, sK> {
+struct tpu_params<tpu::xmx8, void, void, void, sM, sN, sK> {
   static constexpr std::size_t M = -1; // depends on the type
   static constexpr std::size_t N = -1;
   static constexpr std::size_t K = -1;
 
-  bool dynamic_p = false; // no dynamic allocation on the GPU
-  uint32_t numtiles = -1; // does not apply for DPAS
-  scope_t scope = scope_t::sub_group;
+  uint32_t numtiles = -1; // does not apply for XMX8
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
 
   struct combination {
     uint32_t max_msize;
@@ -320,12 +315,12 @@ struct tpu_params<tpu::dpas, void, void, void, sM, sN, sK> {
 // Specialization for when only types are given, need to query only sizes
 
 template <typename Ta, typename Tb, typename Tc>
-struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
+struct tpu_params<tpu::xmx8, Ta, Tb, Tc, 0, 0, 0,
                   typename std::enable_if<(!std::is_same_v<Ta, void> &&
                                            !std::is_same_v<Tb, void> &&
                                            !std::is_same_v<Tc, void>)>::type> {
-  static_assert((are_types_valid_dpas<Ta, Tb, Tc>()),
-                "Invalid types for DPAS, supported types are int8_t, uint8_t, "
+  static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
+                "Invalid types for XMX8, supported types are int8_t, uint8_t, "
                 "half, and bf16 (Note that unsigned short should be used in the"
                 " DPC++ code to implement bf16)");
 
@@ -343,9 +338,9 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
   using joint_matrix_accumulator =
       joint_matrix<Group, Tc, use::accumulator, M, N>;
 
-  bool dynamic_p = false; // no dynamic allocation on the GPU
-  uint32_t numtiles = -1; // does not apply for DPAS
-  scope_t scope = scope_t::sub_group;
+  uint32_t numtiles = -1; // does not apply for XMX8
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
   struct combination {
     uint32_t max_msize;
     uint32_t max_nsize;
@@ -376,13 +371,13 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
 // Specialization when both types and sizes are given
 template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
 struct tpu_params<
-    tpu::dpas, Ta, Tb, Tc, sM, sN, sK,
+    tpu::xmx8, Ta, Tb, Tc, sM, sN, sK,
     typename std::enable_if<((!std::is_same_v<Ta, void> && sM != 0))>::type> {
   // Validate that parameters are supported
   static_assert((sM == 0 && sN == 0 && sK == 0) ||
-                    (is_combination_valid_dpas<Ta, Tb, Tc>(sM, sN, sK)),
-                "Invalid parameters for DPAS, query valid combinations "
-                "using: tpu_params<tpu::dpas> myparams; and then check out "
+                    (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
+                "Invalid parameters for XMX8, query valid combinations "
+                "using: tpu_params<tpu::xmx8> myparams; and then check out "
                 "myparams.combinations array");
 
   // if combination is valid, construct the matrices
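The static_assert message above points users at the general query as the recovery path. A minimal sketch of that pattern (not part of this commit; the namespace path is an assumption, and the combinations array of the xmx8 general-query specialization is assumed to mirror the xmx16 one added further down):

#include <cstdio>
#include <sycl/sycl.hpp>

void list_xmx8_combinations() {
  using namespace sycl::ext::oneapi::experimental::matrix; // assumed path

  // Leaving the types void selects the general-query specialization; its
  // combinations array lists every shape the implementation supports.
  using params = tpu_params<tpu::xmx8>;
  for (int i = 0; i < params::num_combinations; ++i) {
    std::printf("msize=%u nsize=%u ksize=%u\n", params::combinations[i].msize,
                params::combinations[i].nsize, params::combinations[i].ksize);
  }
}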
@@ -399,9 +394,200 @@ struct tpu_params<
   using joint_matrix_accumulator =
       joint_matrix<Group, Tc, use::accumulator, M, N>;
 
-  bool dynamic_p = false; // no dynamic allocation on the GPU
-  uint32_t numtiles = -1; // does not apply for DPAS
-  scope_t scope = scope_t::sub_group;
+  uint32_t numtiles = -1; // does not apply for XMX8
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
+};
+
+// Intel XMX with SIMD16 capability
+// The Intel XMX implementation supports the logical capability support of the
+// HW So in this case, M, N, K sizes returned by the query represent the logical
+// capabilities of the Intel XMX hardware.
+
+template <typename Ta, typename Tb, typename Tc>
+constexpr bool is_combination_valid_xmx16(int sM, int sN, int sK) {
+  if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
+       std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
+       sN == 16 && sK == 32) ||
+      (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
+       std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
+       sN == 16 && sK == 32) ||
+      (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
+       std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
+       sN == 16 && sK == 32) ||
+      (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
+       std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
+       sN == 16 && sK == 32) ||
+      (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
+       std::is_same_v<Tc, float> &&
+       (sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 16 && sK == 16) ||
+      (std::is_same_v<Ta, unsigned short> &&
+       std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
+       (sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 16 && sK == 16))
+    return true;
+  else
+    return false;
+}
+
+template <typename Ta, typename Tb, typename Tc>
+constexpr bool are_types_valid_xmx16() {
+  if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
+       std::is_same_v<Tc, int>) ||
+      (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
+       std::is_same_v<Tc, int>) ||
+      (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
+       std::is_same_v<Tc, int>) ||
+      (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
+       std::is_same_v<Tc, int>) ||
+      (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
+       std::is_same_v<Tc, float>) ||
+      (std::is_same_v<Ta, unsigned short> &&
+       std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
+    return true;
+  else
+    return false;
+}
+
+// General Query
+// specialization for when types are not given --> no default values
+template <int sM, int sN, int sK>
+struct tpu_params<tpu::xmx16, void, void, void, sM, sN, sK> {
+  static constexpr std::size_t M = -1; // depends on the type
+  static constexpr std::size_t N = -1;
+  static constexpr std::size_t K = -1;
+
+  uint32_t numtiles = -1; // does not apply for XMX
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
+
+  struct combination {
+    uint32_t max_msize;
+    uint32_t max_nsize;
+    uint32_t max_ksize;
+    matrix_type atype;
+    matrix_type btype;
+    matrix_type accumulatortype;
+    uint32_t msize;
+    uint32_t nsize;
+    uint32_t ksize;
+  };
+  using mt = matrix_type;
+  static constexpr combination combinations[] = {
+      {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 1, 16, 32},
+      {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 2, 16, 32},
+      {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 4, 16, 32},
+      {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 8, 16, 32},
+      {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 1, 16, 32},
+      {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 2, 16, 32},
+      {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 4, 16, 32},
+      {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 8, 16, 32},
+      {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 1, 16, 32},
+      {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 2, 16, 32},
+      {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 4, 16, 32},
+      {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 8, 16, 32},
+      {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 1, 16, 32},
+      {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 2, 16, 32},
+      {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 4, 16, 32},
+      {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 8, 16, 32},
+      {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 1, 16, 16},
+      {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 2, 16, 16},
+      {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 4, 16, 16},
+      {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 8, 16, 16},
+      {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 1, 16, 16},
+      {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 2, 16, 16},
+      {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 4, 16, 16},
+      {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 8, 16, 16},
+  };
+  static constexpr int num_combinations =
+      sizeof(combinations) / sizeof(combination);
+};
+
+// Sizes-only query:
+// Specialization for when only types are given, need to query only sizes
+
+template <typename Ta, typename Tb, typename Tc>
+struct tpu_params<tpu::xmx16, Ta, Tb, Tc, 0, 0, 0,
+                  typename std::enable_if<(!std::is_same_v<Ta, void> &&
+                                           !std::is_same_v<Tb, void> &&
+                                           !std::is_same_v<Tc, void>)>::type> {
+  static_assert((are_types_valid_xmx16<Ta, Tb, Tc>()),
+                "Invalid types for XMX16, supported types are int8_t, uint8_t, "
+                "half, and bf16 (Note that unsigned short should be used in the"
+                " DPC++ code to implement bf16)");
+
+  // construct the matrices using the default sizes
+
+  static constexpr std::size_t M = 8;
+  static constexpr std::size_t N = 16;
+  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_accumulator =
+      joint_matrix<Group, Tc, use::accumulator, M, N>;
+
+  uint32_t numtiles = -1; // does not apply for XMX
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
+  struct combination {
+    uint32_t max_msize;
+    uint32_t max_nsize;
+    uint32_t max_ksize;
+    matrix_type atype;
+    matrix_type btype;
+    matrix_type accumulatortype;
+    uint32_t msize;
+    uint32_t nsize;
+    uint32_t ksize;
+  };
+  using mt = matrix_type;
+  static constexpr combination combinations[] = {
+      // The types used in the initialization below are fake and not used. In
+      // this case, users already chose the types, they are only looking for
+      // the
+      // sizes
+      {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 1, 16, (sizeof(Ta) == 1) ? 32 : 16},
+      {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 2, 16, (sizeof(Ta) == 1) ? 32 : 16},
+      {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 4, 16, (sizeof(Ta) == 1) ? 32 : 16},
+      {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 8, 16, (sizeof(Ta) == 1) ? 32 : 16},
+  };
+  static constexpr int num_combinations =
+      sizeof(combinations) / sizeof(combination);
+};
+
+// Valid or not:
+// Specialization when both types and sizes are given
+template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
+struct tpu_params<
+    tpu::xmx16, Ta, Tb, Tc, sM, sN, sK,
+    typename std::enable_if<((!std::is_same_v<Ta, void> && sM != 0))>::type> {
+  // Validate that parameters are supported
+  static_assert((sM == 0 && sN == 0 && sK == 0) ||
+                    (is_combination_valid_xmx16<Ta, Tb, Tc>(sM, sN, sK)),
+                "Invalid parameters for XMX16, query valid combinations "
+                "using: tpu_params<tpu::xmx16> myparams; and then check out "
+                "myparams.combinations array");
+
+  // if combination is valid, construct the matrices
+  static constexpr std::size_t M = (sM != 0) ? sM : 8;
+  static constexpr std::size_t N = (sN != 0) ? sN : 8;
+  static constexpr std::size_t K =
+      (sK != 0) ? sK : ((sizeof(Ta) == 1) ? 32 : 16);
+
+  template <typename Group, layout Layout>
+  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
+  template <typename Group, layout Layout>
+  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
+  template <typename Group>
+  using joint_matrix_accumulator =
+      joint_matrix<Group, Tc, use::accumulator, M, N>;
+
+  uint32_t numtiles = -1; // does not apply for XMX16
+  static constexpr scope_t scopes[] = {scope_t::sub_group};
+  static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t);
 };
 } // namespace experimental::matrix
 } // namespace oneapi
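For the newly added tpu::xmx16 entry, the sizes-only specialization is the typical entry point: pick the element types and let the query choose M, N, K and the joint_matrix aliases. A rough sketch follows (not part of this commit; the namespace path is an assumption, Group is expected to be the sub-group type inside a kernel, and bf16 is spelled unsigned short as the static_assert message notes).

#include <sycl/sycl.hpp>

template <typename Group>
void make_default_xmx16_tiles() {
  using namespace sycl::ext::oneapi::experimental::matrix; // assumed path

  // Types only: the query picks M = 8, N = 16, and K = 16 for the 2-byte
  // bf16 element type (K would be 32 for 1-byte types).
  using params = tpu_params<tpu::xmx16, unsigned short, unsigned short, float>;

  // Matrices with the default shapes, intended for use inside a kernel.
  params::joint_matrix_a<Group, layout::row_major> tile_a;
  params::joint_matrix_b<Group, layout::row_major> tile_b;
  params::joint_matrix_accumulator<Group> tile_c;
  (void)tile_a;
  (void)tile_b;
  (void)tile_c;
}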