Skip to content

Commit 23cb7da

Browse files
author
Hugh Delaney
committed
Changing to precision enum
1 parent 61b3d8f commit 23cb7da

File tree

2 files changed

+90
-74
lines changed

2 files changed

+90
-74
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 86 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,80 +18,96 @@ enum class matrix_use { a, b, accumulator };
1818

1919
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
2020

21-
enum class use_tf32 { yes, no };
21+
enum class precision { standard, tf32 /* TODO add more precisions*/ };
2222

2323
template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
2424
size_t Cols = sycl::dynamic_extent,
2525
matrix_layout Layout = matrix_layout::row_major,
26-
typename Group = sycl::sub_group, use_tf32 Tf32 = use_tf32::no,
27-
typename Cond = void>
26+
typename Group = sycl::sub_group,
27+
precision Prec = precision::standard, typename Cond = void>
2828
struct joint_matrix;
2929

3030
#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size, \
31-
Tf32) \
31+
Prec) \
3232
template <matrix_layout Layout> \
3333
struct joint_matrix< \
34-
type, matrix_use::use, M, N, Layout, sycl::sub_group, Tf32, \
34+
type, matrix_use::use, M, N, Layout, sycl::sub_group, Prec, \
3535
typename std::enable_if_t<Layout == matrix_layout::row_major || \
3636
Layout == matrix_layout::col_major>> { \
3737
frag_type data[frag_size]; \
3838
};
3939

4040
// m8n8k4 double only
41-
__SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1, use_tf32::no)
42-
__SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1, use_tf32::no)
43-
__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2, use_tf32::no)
41+
__SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1, precision::standard)
42+
__SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1, precision::standard)
43+
__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2,
44+
precision::standard)
4445

4546
// m8n32k16
4647
// bf16 data format uses uint16_t data type
47-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2, use_tf32::no)
48-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8, use_tf32::no)
49-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8, use_tf32::no)
50-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8, use_tf32::no)
51-
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8, use_tf32::no)
52-
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4, use_tf32::no)
53-
54-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1, use_tf32::no)
55-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4, use_tf32::no)
56-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1, use_tf32::no)
57-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4, use_tf32::no)
48+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2,
49+
precision::standard)
50+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8,
51+
precision::standard)
52+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8, precision::standard)
53+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8, precision::standard)
54+
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8,
55+
precision::standard)
56+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4,
57+
precision::standard)
58+
59+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1, precision::standard)
60+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4, precision::standard)
61+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1, precision::standard)
62+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4,
63+
precision::standard)
5864
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8,
59-
use_tf32::no)
65+
precision::standard)
6066

6167
// m32n8k16
62-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8, use_tf32::no)
63-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2, use_tf32::no)
64-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8, use_tf32::no)
65-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8, use_tf32::no)
66-
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8, use_tf32::no)
67-
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4, use_tf32::no)
68-
69-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4, use_tf32::no)
70-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1, use_tf32::no)
71-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4, use_tf32::no)
72-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1, use_tf32::no)
68+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8,
69+
precision::standard)
70+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2,
71+
precision::standard)
72+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8, precision::standard)
73+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8, precision::standard)
74+
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8,
75+
precision::standard)
76+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4,
77+
precision::standard)
78+
79+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4, precision::standard)
80+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1, precision::standard)
81+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4,
82+
precision::standard)
83+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1, precision::standard)
7384
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8,
74-
use_tf32::no)
85+
precision::standard)
7586

7687
// m16n16k16
77-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4, use_tf32::no)
78-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4, use_tf32::no)
79-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8, use_tf32::no)
80-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8, use_tf32::no)
81-
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8, use_tf32::no)
88+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4,
89+
precision::standard)
90+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4,
91+
precision::standard)
92+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8, precision::standard)
93+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8, precision::standard)
94+
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8,
95+
precision::standard)
8296
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4,
83-
use_tf32::no)
84-
85-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2, use_tf32::no)
86-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2, use_tf32::no)
87-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2, use_tf32::no)
88-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2, use_tf32::no)
97+
precision::standard)
98+
99+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2, precision::standard)
100+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2, precision::standard)
101+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2,
102+
precision::standard)
103+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2,
104+
precision::standard)
89105
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8,
90-
use_tf32::no)
106+
precision::standard)
91107

92108
// m16n16k8 tf32
93-
__SYCL_JOINT_MATRIX_OVERLOAD(float, a, 16, 8, int32_t, 4, use_tf32::yes)
94-
__SYCL_JOINT_MATRIX_OVERLOAD(float, b, 8, 16, int32_t, 4, use_tf32::yes)
109+
__SYCL_JOINT_MATRIX_OVERLOAD(float, a, 16, 8, int32_t, 4, precision::tf32)
110+
__SYCL_JOINT_MATRIX_OVERLOAD(float, b, 8, 16, int32_t, 4, precision::tf32)
95111

96112
#undef __SYCL_JOINT_MATRIX_OVERLOAD
97113
} // namespace experimental::matrix
@@ -102,12 +118,12 @@ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
102118
size_t NumRows, size_t NumCols,
103119
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
104120
access::address_space Space,
105-
sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32 =
106-
sycl::ext::oneapi::experimental::matrix::use_tf32::no,
121+
sycl::ext::oneapi::experimental::matrix::precision Prec =
122+
sycl::ext::oneapi::experimental::matrix::precision::standard,
107123
typename Cond = void>
108124
struct joint_matrix_load_impl {
109125
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
110-
T, Use, NumRows, NumCols, Layout, sycl::sub_group, Tf32> &res,
126+
T, Use, NumRows, NumCols, Layout, sycl::sub_group, Prec> &res,
111127
multi_ptr<T, Space> src, size_t stride);
112128
};
113129

@@ -130,15 +146,15 @@ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
130146
size_t NumRows, size_t NumCols,
131147
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
132148
access::address_space Space,
133-
sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32>
149+
sycl::ext::oneapi::experimental::matrix::precision Prec>
134150
struct joint_matrix_load_impl<
135-
T, Use, NumRows, NumCols, Layout, Space, Tf32,
151+
T, Use, NumRows, NumCols, Layout, Space, Prec,
136152
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
137153
matrix::matrix_layout::row_major ||
138154
Layout == sycl::ext::oneapi::experimental::
139155
matrix::matrix_layout::col_major>> {
140156
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
141-
T, Use, NumRows, NumCols, Layout, sycl::sub_group, Tf32> &res,
157+
T, Use, NumRows, NumCols, Layout, sycl::sub_group, Prec> &res,
142158
multi_ptr<T, Space> src, size_t stride) {
143159
if constexpr (std::is_same<T, uint16_t>::value) {
144160
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
@@ -263,8 +279,8 @@ struct joint_matrix_load_impl<
263279
get_layout_id<Layout>());
264280
}
265281
} else if constexpr (std::is_same<T, float>::value) {
266-
if constexpr (Tf32 ==
267-
sycl::ext::oneapi::experimental::matrix::use_tf32::yes) {
282+
if constexpr (Prec ==
283+
sycl::ext::oneapi::experimental::matrix::precision::tf32) {
268284
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
269285
if constexpr (NumRows == 16 && NumCols == 8) {
270286
__mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride,
@@ -379,19 +395,19 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
379395
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
380396
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
381397
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
382-
sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32,
398+
sycl::ext::oneapi::experimental::matrix::precision Prec,
383399
typename Cond = void>
384400
struct joint_matrix_mad_impl {
385401
sycl::ext::oneapi::experimental::matrix::joint_matrix<
386402
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
387403
N, LayoutC, sycl::sub_group>
388404
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
389405
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
390-
LayoutA, sycl::sub_group, Tf32>
406+
LayoutA, sycl::sub_group, Prec>
391407
A,
392408
sycl::ext::oneapi::experimental::matrix::joint_matrix<
393409
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
394-
LayoutB, sycl::sub_group, Tf32>
410+
LayoutB, sycl::sub_group, Prec>
395411
B,
396412
sycl::ext::oneapi::experimental::matrix::joint_matrix<
397413
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -435,9 +451,9 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
435451
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
436452
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
437453
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
438-
sycl::ext::oneapi::experimental::matrix::use_tf32 Tf32>
454+
sycl::ext::oneapi::experimental::matrix::precision Prec>
439455
struct joint_matrix_mad_impl<
440-
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Tf32,
456+
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec,
441457
typename std::enable_if_t<
442458
(LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
443459
row_major ||
@@ -456,11 +472,11 @@ struct joint_matrix_mad_impl<
456472
N, LayoutC, sycl::sub_group>
457473
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
458474
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
459-
LayoutA, sycl::sub_group, Tf32>
475+
LayoutA, sycl::sub_group, Prec>
460476
A,
461477
sycl::ext::oneapi::experimental::matrix::joint_matrix<
462478
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
463-
LayoutB, sycl::sub_group, Tf32>
479+
LayoutB, sycl::sub_group, Prec>
464480
B,
465481
sycl::ext::oneapi::experimental::matrix::joint_matrix<
466482
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -530,8 +546,8 @@ struct joint_matrix_mad_impl<
530546
}
531547
}
532548
} else if constexpr (M == 16 && N == 16 && K == 8 &&
533-
Tf32 == sycl::ext::oneapi::experimental::matrix::
534-
use_tf32::yes) {
549+
Prec == sycl::ext::oneapi::experimental::matrix::
550+
precision::tf32) {
535551
__mma_tf32_m16n16k8_mma_f32(D.data, A.data, B.data, C.data,
536552
get_layout_pair_id<LayoutA, LayoutB>(), 0);
537553
} else if constexpr (std::is_same<T1, double>::value) {
@@ -548,13 +564,13 @@ namespace experimental::matrix {
548564

549565
template <typename Group, typename T, matrix_use Use, size_t NumRows,
550566
size_t NumCols, matrix_layout Layout, access::address_space Space,
551-
use_tf32 Tf32 = use_tf32::no>
567+
precision Prec = precision::standard>
552568
void joint_matrix_load(
553-
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group, Tf32> &res,
569+
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group, Prec> &res,
554570
multi_ptr<T, Space> src, size_t stride) {
555571
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
556572
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
557-
Layout, Space, Tf32>{}
573+
Layout, Space, Prec>{}
558574
.load(res, src, stride);
559575
#else
560576
(void)sg;
@@ -592,15 +608,15 @@ void joint_matrix_store(Group sg,
592608

593609
template <typename Group, typename T1, typename T2, std::size_t M,
594610
std::size_t K, std::size_t N, matrix_layout LayoutA,
595-
matrix_layout LayoutB, matrix_layout LayoutC, use_tf32 Tf32>
611+
matrix_layout LayoutB, matrix_layout LayoutC, precision Prec>
596612
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
597613
joint_matrix_mad(
598-
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group, Tf32> A,
599-
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group, Tf32> B,
614+
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group, Prec> A,
615+
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group, Prec> B,
600616
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
601617
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
602618
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
603-
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Tf32>{}
619+
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec>{}
604620
.mad(A, B, C);
605621
#else
606622
(void)sg;

sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ int main() {
5959
sycl::sub_group sg = item.get_sub_group();
6060

6161
joint_matrix<float, matrix_use::a, M, K, matrix_layout::row_major,
62-
sycl::sub_group, use_tf32::yes>
62+
sycl::sub_group, precision::tf32>
6363
sub_a;
6464

6565
joint_matrix<float, matrix_use::b, K, N, matrix_layout::row_major,
66-
sycl::sub_group, use_tf32::yes>
66+
sycl::sub_group, precision::tf32>
6767
sub_b;
6868

6969
joint_matrix<float, matrix_use::accumulator, M, N,
@@ -95,11 +95,11 @@ int main() {
9595
sycl::sub_group sg = item.get_sub_group();
9696

9797
joint_matrix<float, matrix_use::a, M, K, matrix_layout::col_major,
98-
sycl::sub_group, use_tf32::yes>
98+
sycl::sub_group, precision::tf32>
9999
sub_a;
100100

101101
joint_matrix<float, matrix_use::b, K, N, matrix_layout::col_major,
102-
sycl::sub_group, use_tf32::yes>
102+
sycl::sub_group, precision::tf32>
103103
sub_b;
104104

105105
joint_matrix<float, matrix_use::accumulator, M, N,

0 commit comments

Comments
 (0)