@@ -18,80 +18,96 @@ enum class matrix_use { a, b, accumulator };
18
18
19
19
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
20
20
21
- enum class use_tf32 { yes, no };
21
+ enum class precision { standard, tf32 /* TODO add more precisions */ };
22
22
23
23
template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
24
24
size_t Cols = sycl::dynamic_extent,
25
25
matrix_layout Layout = matrix_layout::row_major,
26
- typename Group = sycl::sub_group, use_tf32 Tf32 = use_tf32::no,
27
- typename Cond = void >
26
+ typename Group = sycl::sub_group,
27
+ precision Prec = precision::standard, typename Cond = void >
28
28
struct joint_matrix ;
29
29
30
30
#define __SYCL_JOINT_MATRIX_OVERLOAD (type, use, M, N, frag_type, frag_size, \
31
- Tf32 ) \
31
+ Prec ) \
32
32
template <matrix_layout Layout> \
33
33
struct joint_matrix < \
34
- type, matrix_use::use, M, N, Layout, sycl::sub_group, Tf32 , \
34
+ type, matrix_use::use, M, N, Layout, sycl::sub_group, Prec , \
35
35
typename std::enable_if_t <Layout == matrix_layout::row_major || \
36
36
Layout == matrix_layout::col_major>> { \
37
37
frag_type data[frag_size]; \
38
38
};
39
39
40
40
// m8n8k4 double only
41
- __SYCL_JOINT_MATRIX_OVERLOAD (double , a, 8 , 4 , double , 1 , use_tf32::no)
42
- __SYCL_JOINT_MATRIX_OVERLOAD (double , b, 4 , 8 , double , 1 , use_tf32::no)
43
- __SYCL_JOINT_MATRIX_OVERLOAD (double , accumulator, 8 , 8 , double , 2 , use_tf32::no)
41
+ __SYCL_JOINT_MATRIX_OVERLOAD (double , a, 8 , 4 , double , 1 , precision::standard)
42
+ __SYCL_JOINT_MATRIX_OVERLOAD (double , b, 4 , 8 , double , 1 , precision::standard)
43
+ __SYCL_JOINT_MATRIX_OVERLOAD (double , accumulator, 8 , 8 , double , 2 ,
44
+ precision::standard)
44
45
45
46
// m8n32k16
46
47
// bf16 data format uses uint16_t data type
47
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 8 , 16 , int32_t , 2 , use_tf32::no)
48
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 32 , int32_t , 8 , use_tf32::no)
49
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 8 , 16 , int32_t , 8 , use_tf32::no)
50
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 32 , int32_t , 8 , use_tf32::no)
51
- __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 8 , 32 , float , 8 , use_tf32::no)
52
- __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 8 , 32 , int32_t , 4 , use_tf32::no)
53
-
54
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 8 , 16 , int32_t , 1 , use_tf32::no)
55
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 32 , int32_t , 4 , use_tf32::no)
56
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 8 , 16 , int32_t , 1 , use_tf32::no)
57
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 32 , int32_t , 4 , use_tf32::no)
48
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 8 , 16 , int32_t , 2 ,
49
+ precision::standard)
50
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 32 , int32_t , 8 ,
51
+ precision::standard)
52
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 8 , 16 , int32_t , 8 , precision::standard)
53
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 32 , int32_t , 8 , precision::standard)
54
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 8 , 32 , float , 8 ,
55
+ precision::standard)
56
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 8 , 32 , int32_t , 4 ,
57
+ precision::standard)
58
+
59
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 8 , 16 , int32_t , 1 , precision::standard)
60
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 32 , int32_t , 4 , precision::standard)
61
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 8 , 16 , int32_t , 1 , precision::standard)
62
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 32 , int32_t , 4 ,
63
+ precision::standard)
58
64
__SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 8 , 32 , int32_t , 8 ,
59
- use_tf32::no )
65
+ precision::standard )
60
66
61
67
// m32n8k16
62
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 32 , 16 , int32_t , 8 , use_tf32::no)
63
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 8 , int32_t , 2 , use_tf32::no)
64
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 32 , 16 , int32_t , 8 , use_tf32::no)
65
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 8 , int32_t , 8 , use_tf32::no)
66
- __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 32 , 8 , float , 8 , use_tf32::no)
67
- __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 32 , 8 , int32_t , 4 , use_tf32::no)
68
-
69
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 32 , 16 , int32_t , 4 , use_tf32::no)
70
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 8 , int32_t , 1 , use_tf32::no)
71
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 32 , 16 , int32_t , 4 , use_tf32::no)
72
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 8 , int32_t , 1 , use_tf32::no)
68
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 32 , 16 , int32_t , 8 ,
69
+ precision::standard)
70
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 8 , int32_t , 2 ,
71
+ precision::standard)
72
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 32 , 16 , int32_t , 8 , precision::standard)
73
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 8 , int32_t , 8 , precision::standard)
74
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 32 , 8 , float , 8 ,
75
+ precision::standard)
76
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 32 , 8 , int32_t , 4 ,
77
+ precision::standard)
78
+
79
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 32 , 16 , int32_t , 4 , precision::standard)
80
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 8 , int32_t , 1 , precision::standard)
81
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 32 , 16 , int32_t , 4 ,
82
+ precision::standard)
83
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 8 , int32_t , 1 , precision::standard)
73
84
__SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 32 , 8 , int32_t , 8 ,
74
- use_tf32::no )
85
+ precision::standard )
75
86
76
87
// m16n16k16
77
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 16 , 16 , int32_t , 4 , use_tf32::no)
78
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 16 , int32_t , 4 , use_tf32::no)
79
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 16 , 16 , int32_t , 8 , use_tf32::no)
80
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 16 , int32_t , 8 , use_tf32::no)
81
- __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 16 , 16 , float , 8 , use_tf32::no)
88
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 16 , 16 , int32_t , 4 ,
89
+ precision::standard)
90
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 16 , int32_t , 4 ,
91
+ precision::standard)
92
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 16 , 16 , int32_t , 8 , precision::standard)
93
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 16 , int32_t , 8 , precision::standard)
94
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 16 , 16 , float , 8 ,
95
+ precision::standard)
82
96
__SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 16 , 16 , int32_t , 4 ,
83
- use_tf32::no)
84
-
85
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 16 , 16 , int32_t , 2 , use_tf32::no)
86
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 16 , int32_t , 2 , use_tf32::no)
87
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 16 , 16 , int32_t , 2 , use_tf32::no)
88
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 16 , int32_t , 2 , use_tf32::no)
97
+ precision::standard)
98
+
99
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 16 , 16 , int32_t , 2 , precision::standard)
100
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 16 , int32_t , 2 , precision::standard)
101
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 16 , 16 , int32_t , 2 ,
102
+ precision::standard)
103
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 16 , int32_t , 2 ,
104
+ precision::standard)
89
105
__SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 16 , 16 , int32_t , 8 ,
90
- use_tf32::no )
106
+ precision::standard )
91
107
92
108
// m16n16k8 tf32
93
- __SYCL_JOINT_MATRIX_OVERLOAD (float , a, 16 , 8 , int32_t , 4 , use_tf32::yes )
94
- __SYCL_JOINT_MATRIX_OVERLOAD (float , b, 8 , 16 , int32_t , 4 , use_tf32::yes )
109
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , a, 16 , 8 , int32_t , 4 , precision::tf32 )
110
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , b, 8 , 16 , int32_t , 4 , precision::tf32 )
95
111
96
112
#undef __SYCL_JOINT_MATRIX_OVERLOAD
97
113
} // namespace experimental::matrix
@@ -102,12 +118,12 @@ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
102
118
size_t NumRows, size_t NumCols,
103
119
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
104
120
access::address_space Space,
105
- sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32 =
106
- sycl::ext::oneapi::experimental::matrix::use_tf32::no ,
121
+ sycl::ext::oneapi::experimental::matrix::precision Prec =
122
+ sycl::ext::oneapi::experimental::matrix::precision::standard ,
107
123
typename Cond = void >
108
124
struct joint_matrix_load_impl {
109
125
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
110
- T, Use, NumRows, NumCols, Layout, sycl::sub_group, Tf32 > &res,
126
+ T, Use, NumRows, NumCols, Layout, sycl::sub_group, Prec > &res,
111
127
multi_ptr<T, Space> src, size_t stride);
112
128
};
113
129
@@ -130,15 +146,15 @@ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
130
146
size_t NumRows, size_t NumCols,
131
147
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
132
148
access::address_space Space,
133
- sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32 >
149
+ sycl::ext::oneapi::experimental::matrix::precision Prec >
134
150
struct joint_matrix_load_impl <
135
- T, Use, NumRows, NumCols, Layout, Space, Tf32 ,
151
+ T, Use, NumRows, NumCols, Layout, Space, Prec ,
136
152
typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
137
153
matrix::matrix_layout::row_major ||
138
154
Layout == sycl::ext::oneapi::experimental::
139
155
matrix::matrix_layout::col_major>> {
140
156
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
141
- T, Use, NumRows, NumCols, Layout, sycl::sub_group, Tf32 > &res,
157
+ T, Use, NumRows, NumCols, Layout, sycl::sub_group, Prec > &res,
142
158
multi_ptr<T, Space> src, size_t stride) {
143
159
if constexpr (std::is_same<T, uint16_t >::value) {
144
160
int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
@@ -263,8 +279,8 @@ struct joint_matrix_load_impl<
263
279
get_layout_id<Layout>());
264
280
}
265
281
} else if constexpr (std::is_same<T, float >::value) {
266
- if constexpr (Tf32 ==
267
- sycl::ext::oneapi::experimental::matrix::use_tf32::yes ) {
282
+ if constexpr (Prec ==
283
+ sycl::ext::oneapi::experimental::matrix::precision::tf32 ) {
268
284
int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
269
285
if constexpr (NumRows == 16 && NumCols == 8 ) {
270
286
__mma_tf32_m16n16k8_ld_a (res.data , tileptr, stride,
@@ -379,19 +395,19 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
379
395
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
380
396
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
381
397
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
382
- sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32 ,
398
+ sycl::ext::oneapi::experimental::matrix::precision Prec ,
383
399
typename Cond = void >
384
400
struct joint_matrix_mad_impl {
385
401
sycl::ext::oneapi::experimental::matrix::joint_matrix<
386
402
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
387
403
N, LayoutC, sycl::sub_group>
388
404
mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
389
405
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
390
- LayoutA, sycl::sub_group, Tf32 >
406
+ LayoutA, sycl::sub_group, Prec >
391
407
A,
392
408
sycl::ext::oneapi::experimental::matrix::joint_matrix<
393
409
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
394
- LayoutB, sycl::sub_group, Tf32 >
410
+ LayoutB, sycl::sub_group, Prec >
395
411
B,
396
412
sycl::ext::oneapi::experimental::matrix::joint_matrix<
397
413
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -435,9 +451,9 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
435
451
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
436
452
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
437
453
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
438
- sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32 >
454
+ sycl::ext::oneapi::experimental::matrix::precision Prec >
439
455
struct joint_matrix_mad_impl <
440
- T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Tf32 ,
456
+ T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec ,
441
457
typename std::enable_if_t <
442
458
(LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
443
459
row_major ||
@@ -456,11 +472,11 @@ struct joint_matrix_mad_impl<
456
472
N, LayoutC, sycl::sub_group>
457
473
mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
458
474
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
459
- LayoutA, sycl::sub_group, Tf32 >
475
+ LayoutA, sycl::sub_group, Prec >
460
476
A,
461
477
sycl::ext::oneapi::experimental::matrix::joint_matrix<
462
478
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
463
- LayoutB, sycl::sub_group, Tf32 >
479
+ LayoutB, sycl::sub_group, Prec >
464
480
B,
465
481
sycl::ext::oneapi::experimental::matrix::joint_matrix<
466
482
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -530,8 +546,8 @@ struct joint_matrix_mad_impl<
530
546
}
531
547
}
532
548
} else if constexpr (M == 16 && N == 16 && K == 8 &&
533
- Tf32 == sycl::ext::oneapi::experimental::matrix::
534
- use_tf32::yes ) {
549
+ Prec == sycl::ext::oneapi::experimental::matrix::
550
+ precision::tf32 ) {
535
551
__mma_tf32_m16n16k8_mma_f32 (D.data , A.data , B.data , C.data ,
536
552
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
537
553
} else if constexpr (std::is_same<T1, double >::value) {
@@ -548,13 +564,13 @@ namespace experimental::matrix {
548
564
549
565
template <typename Group, typename T, matrix_use Use, size_t NumRows,
550
566
size_t NumCols, matrix_layout Layout, access::address_space Space,
551
- use_tf32 Tf32 = use_tf32::no >
567
+ precision Prec = precision::standard >
552
568
void joint_matrix_load (
553
- Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group, Tf32 > &res,
569
+ Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group, Prec > &res,
554
570
multi_ptr<T, Space> src, size_t stride) {
555
571
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
556
572
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
557
- Layout, Space, Tf32 >{}
573
+ Layout, Space, Prec >{}
558
574
.load (res, src, stride);
559
575
#else
560
576
(void )sg;
@@ -592,15 +608,15 @@ void joint_matrix_store(Group sg,
592
608
593
609
template <typename Group, typename T1, typename T2, std::size_t M,
594
610
std::size_t K, std::size_t N, matrix_layout LayoutA,
595
- matrix_layout LayoutB, matrix_layout LayoutC, use_tf32 Tf32 >
611
+ matrix_layout LayoutB, matrix_layout LayoutC, precision Prec >
596
612
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
597
613
joint_matrix_mad (
598
- Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group, Tf32 > A,
599
- joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group, Tf32 > B,
614
+ Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group, Prec > A,
615
+ joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group, Prec > B,
600
616
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
601
617
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
602
618
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
603
- T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Tf32 >{}
619
+ T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec >{}
604
620
.mad (A, B, C);
605
621
#else
606
622
(void )sg;
0 commit comments