Skip to content

Commit be60cdd

Browse files
author
Hugh Delaney
committed
Responding to comments. Using precision::tf32 as empty class and float as the fragment type. Also adding free functions tf32_to_float and float_to_tf32
1 parent 23cb7da commit be60cdd

File tree

2 files changed

+136
-127
lines changed

2 files changed

+136
-127
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 108 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -18,112 +18,90 @@ enum class matrix_use { a, b, accumulator };
1818

1919
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
2020

21-
enum class precision { standard, tf32 /* TODO add more precisions*/ };
21+
namespace precision {
22+
// Empty tag type naming the TF32 precision. It carries no data itself; it is
// used as the element-type argument of joint_matrix (see the m16n16k8
// overloads in this file, whose fragments are stored as float).
class tf32 {};
23+
} // namespace precision
2224

2325
template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
2426
size_t Cols = sycl::dynamic_extent,
2527
matrix_layout Layout = matrix_layout::row_major,
26-
typename Group = sycl::sub_group,
27-
precision Prec = precision::standard, typename Cond = void>
28+
typename Group = sycl::sub_group, typename Cond = void>
2829
struct joint_matrix;
2930

30-
#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size, \
31-
Prec) \
31+
#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size) \
3232
template <matrix_layout Layout> \
3333
struct joint_matrix< \
34-
type, matrix_use::use, M, N, Layout, sycl::sub_group, Prec, \
34+
type, matrix_use::use, M, N, Layout, sycl::sub_group, \
3535
typename std::enable_if_t<Layout == matrix_layout::row_major || \
3636
Layout == matrix_layout::col_major>> { \
3737
frag_type data[frag_size]; \
3838
};
3939

4040
// m8n8k4 double only
41-
__SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1, precision::standard)
42-
__SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1, precision::standard)
43-
__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2,
44-
precision::standard)
41+
__SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1)
42+
__SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1)
43+
__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2)
4544

4645
// m8n32k16
4746
// bf16 data format uses uint16_t data type
48-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2,
49-
precision::standard)
50-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8,
51-
precision::standard)
52-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8, precision::standard)
53-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8, precision::standard)
54-
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8,
55-
precision::standard)
56-
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4,
57-
precision::standard)
58-
59-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1, precision::standard)
60-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4, precision::standard)
61-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1, precision::standard)
62-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4,
63-
precision::standard)
64-
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8,
65-
precision::standard)
47+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2)
48+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8)
49+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8)
50+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8)
51+
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8)
52+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4)
53+
54+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1)
55+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4)
56+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1)
57+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4)
58+
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8)
6659

6760
// m32n8k16
68-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8,
69-
precision::standard)
70-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2,
71-
precision::standard)
72-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8, precision::standard)
73-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8, precision::standard)
74-
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8,
75-
precision::standard)
76-
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4,
77-
precision::standard)
78-
79-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4, precision::standard)
80-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1, precision::standard)
81-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4,
82-
precision::standard)
83-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1, precision::standard)
84-
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8,
85-
precision::standard)
61+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8)
62+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2)
63+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8)
64+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8)
65+
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8)
66+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4)
67+
68+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4)
69+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1)
70+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4)
71+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1)
72+
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8)
8673

8774
// m16n16k16
88-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4,
89-
precision::standard)
90-
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4,
91-
precision::standard)
92-
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8, precision::standard)
93-
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8, precision::standard)
94-
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8,
95-
precision::standard)
96-
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4,
97-
precision::standard)
98-
99-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2, precision::standard)
100-
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2, precision::standard)
101-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2,
102-
precision::standard)
103-
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2,
104-
precision::standard)
105-
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8,
106-
precision::standard)
75+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4)
76+
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4)
77+
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8)
78+
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8)
79+
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8)
80+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4)
81+
82+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2)
83+
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2)
84+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
85+
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
86+
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)
10787

10888
// m16n16k8 tf32
109-
__SYCL_JOINT_MATRIX_OVERLOAD(float, a, 16, 8, int32_t, 4, precision::tf32)
110-
__SYCL_JOINT_MATRIX_OVERLOAD(float, b, 8, 16, int32_t, 4, precision::tf32)
89+
__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4)
90+
__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4)
11191

11292
#undef __SYCL_JOINT_MATRIX_OVERLOAD
11393
} // namespace experimental::matrix
11494

11595
namespace detail {
11696

117-
template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
97+
template <typename S, typename T,
98+
sycl::ext::oneapi::experimental::matrix::matrix_use Use,
11899
size_t NumRows, size_t NumCols,
119100
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
120-
access::address_space Space,
121-
sycl::ext::oneapi::experimental::matrix::precision Prec =
122-
sycl::ext::oneapi::experimental::matrix::precision::standard,
123-
typename Cond = void>
101+
access::address_space Space, typename Cond = void>
124102
struct joint_matrix_load_impl {
125103
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
126-
T, Use, NumRows, NumCols, Layout, sycl::sub_group, Prec> &res,
104+
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
127105
multi_ptr<T, Space> src, size_t stride);
128106
};
129107

@@ -142,19 +120,19 @@ constexpr int get_layout_id<
142120
return 1;
143121
}
144122

145-
template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
123+
template <typename S, typename T,
124+
sycl::ext::oneapi::experimental::matrix::matrix_use Use,
146125
size_t NumRows, size_t NumCols,
147126
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
148-
access::address_space Space,
149-
sycl::ext::oneapi::experimental::matrix::precision Prec>
127+
access::address_space Space>
150128
struct joint_matrix_load_impl<
151-
T, Use, NumRows, NumCols, Layout, Space, Prec,
129+
S, T, Use, NumRows, NumCols, Layout, Space,
152130
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
153131
matrix::matrix_layout::row_major ||
154132
Layout == sycl::ext::oneapi::experimental::
155133
matrix::matrix_layout::col_major>> {
156134
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
157-
T, Use, NumRows, NumCols, Layout, sycl::sub_group, Prec> &res,
135+
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
158136
multi_ptr<T, Space> src, size_t stride) {
159137
if constexpr (std::is_same<T, uint16_t>::value) {
160138
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
@@ -279,21 +257,7 @@ struct joint_matrix_load_impl<
279257
get_layout_id<Layout>());
280258
}
281259
} else if constexpr (std::is_same<T, float>::value) {
282-
if constexpr (Prec ==
283-
sycl::ext::oneapi::experimental::matrix::precision::tf32) {
284-
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
285-
if constexpr (NumRows == 16 && NumCols == 8) {
286-
__mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride,
287-
get_layout_id<Layout>());
288-
} else if constexpr (NumRows == 8 && NumCols == 16) {
289-
__mma_tf32_m16n16k8_ld_b(res.data, tileptr, stride,
290-
get_layout_id<Layout>());
291-
}
292-
for (int i = 0; i < 4; ++i) {
293-
auto tmpfloat = __nvvm_bitcast_i2f(res.data[i]);
294-
res.data[i] = __nvvm_f2tf32_rna(tmpfloat);
295-
}
296-
} else {
260+
if (std::is_same<S, float>::value) {
297261
if constexpr (NumRows == 16 && NumCols == 16) {
298262
__hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
299263
get_layout_id<Layout>());
@@ -304,6 +268,16 @@ struct joint_matrix_load_impl<
304268
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
305269
get_layout_id<Layout>());
306270
}
271+
} else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
272+
precision::tf32>::value) {
273+
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
274+
if constexpr (NumRows == 16 && NumCols == 8) {
275+
__mma_tf32_m16n16k8_ld_a(reinterpret_cast<int32_t *>(res.data),
276+
tileptr, stride, get_layout_id<Layout>());
277+
} else if constexpr (NumRows == 8 && NumCols == 16) {
278+
__mma_tf32_m16n16k8_ld_b(reinterpret_cast<int32_t *>(res.data),
279+
tileptr, stride, get_layout_id<Layout>());
280+
}
307281
}
308282
} else if constexpr (std::is_same<T, double>::value) {
309283
if constexpr (Use ==
@@ -395,19 +369,18 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
395369
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
396370
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
397371
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
398-
sycl::ext::oneapi::experimental::matrix::precision Prec,
399372
typename Cond = void>
400373
struct joint_matrix_mad_impl {
401374
sycl::ext::oneapi::experimental::matrix::joint_matrix<
402375
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
403376
N, LayoutC, sycl::sub_group>
404377
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
405378
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
406-
LayoutA, sycl::sub_group, Prec>
379+
LayoutA, sycl::sub_group>
407380
A,
408381
sycl::ext::oneapi::experimental::matrix::joint_matrix<
409382
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
410-
LayoutB, sycl::sub_group, Prec>
383+
LayoutB, sycl::sub_group>
411384
B,
412385
sycl::ext::oneapi::experimental::matrix::joint_matrix<
413386
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -450,10 +423,9 @@ constexpr int get_layout_pair_id<
450423
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
451424
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
452425
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
453-
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
454-
sycl::ext::oneapi::experimental::matrix::precision Prec>
426+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
455427
struct joint_matrix_mad_impl<
456-
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec,
428+
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC,
457429
typename std::enable_if_t<
458430
(LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
459431
row_major ||
@@ -472,11 +444,11 @@ struct joint_matrix_mad_impl<
472444
N, LayoutC, sycl::sub_group>
473445
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
474446
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
475-
LayoutA, sycl::sub_group, Prec>
447+
LayoutA, sycl::sub_group>
476448
A,
477449
sycl::ext::oneapi::experimental::matrix::joint_matrix<
478450
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
479-
LayoutB, sycl::sub_group, Prec>
451+
LayoutB, sycl::sub_group>
480452
B,
481453
sycl::ext::oneapi::experimental::matrix::joint_matrix<
482454
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -545,10 +517,9 @@ struct joint_matrix_mad_impl<
545517
get_layout_pair_id<LayoutA, LayoutB>(), 0);
546518
}
547519
}
548-
} else if constexpr (M == 16 && N == 16 && K == 8 &&
549-
Prec == sycl::ext::oneapi::experimental::matrix::
550-
precision::tf32) {
551-
__mma_tf32_m16n16k8_mma_f32(D.data, A.data, B.data, C.data,
520+
} else if constexpr (M == 16 && N == 16 && K == 8) {
521+
__mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
522+
reinterpret_cast<int32_t *>(B.data), C.data,
552523
get_layout_pair_id<LayoutA, LayoutB>(), 0);
553524
} else if constexpr (std::is_same<T1, double>::value) {
554525
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
@@ -562,15 +533,19 @@ struct joint_matrix_mad_impl<
562533

563534
namespace experimental::matrix {
564535

565-
template <typename Group, typename T, matrix_use Use, size_t NumRows,
566-
size_t NumCols, matrix_layout Layout, access::address_space Space,
567-
precision Prec = precision::standard>
536+
template <typename Group, typename S, typename T, matrix_use Use,
537+
size_t NumRows, size_t NumCols, matrix_layout Layout,
538+
access::address_space Space,
539+
std::enable_if_t<std::is_same<S, T>::value ||
540+
(std::is_same<S, precision::tf32>::value &&
541+
std::is_same<T, float>::value),
542+
bool> = true>
568543
void joint_matrix_load(
569-
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group, Prec> &res,
544+
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
570545
multi_ptr<T, Space> src, size_t stride) {
571546
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
572-
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
573-
Layout, Space, Prec>{}
547+
sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
548+
Layout, Space>{}
574549
.load(res, src, stride);
575550
#else
576551
(void)sg;
@@ -608,15 +583,15 @@ void joint_matrix_store(Group sg,
608583

609584
template <typename Group, typename T1, typename T2, std::size_t M,
610585
std::size_t K, std::size_t N, matrix_layout LayoutA,
611-
matrix_layout LayoutB, matrix_layout LayoutC, precision Prec>
586+
matrix_layout LayoutB, matrix_layout LayoutC>
612587
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
613588
joint_matrix_mad(
614-
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group, Prec> A,
615-
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group, Prec> B,
589+
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
590+
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
616591
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
617592
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
618593
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
619-
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec>{}
594+
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
620595
.mad(A, B, C);
621596
#else
622597
(void)sg;
@@ -629,6 +604,24 @@ joint_matrix_mad(
629604
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
630605
}
631606

607+
float float_to_tf32(float a) {
608+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
609+
int32_t tmp_int = __nvvm_f2tf32_rna(a);
610+
return __nvvm_bitcast_i2f(tmp_int);
611+
#else
612+
throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 float_to_tf32 is "
613+
"only supported by CUDA devices",
614+
PI_INVALID_DEVICE);
615+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
616+
}
617+
618+
/// Zeroes the 13 least-significant bits of the 32-bit pattern of \p a and
/// returns the resulting pattern as a float.
///
/// The bit pattern is copied in and out with `__builtin_memcpy` rather than
/// accessed through `reinterpret_cast<uint32_t &>`: reading a float object
/// through a uint32_t lvalue (and vice versa) violates the strict-aliasing
/// rule and is undefined behaviour. The builtin form needs no extra header.
///
/// Declared `inline` because this non-template definition lives in a header;
/// without it each including translation unit would emit its own definition.
///
/// \param a value whose low 13 bits are cleared.
/// \return \p a with the low 13 bits of its representation set to zero.
inline float tf32_to_float(float a) {
  static_assert(sizeof(uint32_t) == sizeof(float),
                "float is expected to be 32 bits wide");
  uint32_t bits;
  __builtin_memcpy(&bits, &a, sizeof(bits));
  // Keep the upper 19 bits, drop the lower 13.
  bits &= 0xFFFFE000u;
  float result;
  __builtin_memcpy(&result, &bits, sizeof(result));
  return result;
}
624+
632625
} // namespace experimental::matrix
633626
} // namespace oneapi
634627
} // namespace ext

0 commit comments

Comments
 (0)