@@ -18,112 +18,90 @@ enum class matrix_use { a, b, accumulator };
18
18
19
19
// Layout tag used as the `Layout` template argument of joint_matrix.
// The joint_matrix specializations generated further down in this file are
// only enabled for row_major and col_major.
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
20
20
21
// Precision selection. The removed line modelled precision as an enum
// (`standard`, `tf32`); the added lines replace it with a namespace that holds
// an empty tag class `tf32`, which the added code below passes as the
// joint_matrix element type instead of a separate precision argument.
- enum class precision { standard, tf32 /* TODO add more precisions*/ };
21
+ namespace precision {
22
+ class tf32 {};
23
+ } // namespace precision
22
24
23
25
// Primary joint_matrix template (declaration only; the usable definitions are
// the partial specializations generated by __SYCL_JOINT_MATRIX_OVERLOAD).
// T      - element type (or the precision::tf32 tag in the added code)
// Use    - role of the matrix: a, b or accumulator
// Rows / Cols - extents, defaulting to sycl::dynamic_extent
// Layout - defaults to matrix_layout::row_major
// Group  - defaults to sycl::sub_group
// Cond   - trailing parameter the specializations fill with std::enable_if_t
// The removed lines carried an extra `precision Prec` parameter before Cond.
template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
24
26
size_t Cols = sycl::dynamic_extent,
25
27
matrix_layout Layout = matrix_layout::row_major,
26
- typename Group = sycl::sub_group,
27
- precision Prec = precision::standard, typename Cond = void >
28
+ typename Group = sycl::sub_group, typename Cond = void >
28
29
struct joint_matrix ;
29
30
30
// Helper macro: emits one joint_matrix partial specialization for
// sycl::sub_group with element type `type`, role `matrix_use::use` and extents
// M x N. The specialization is enabled only when Layout is row_major or
// col_major, and its sole member is the fragment storage
// `frag_type data[frag_size]`.
// The removed form took a trailing `Prec` argument that was forwarded into the
// specialization's argument list; the added form has no such argument.
- #define __SYCL_JOINT_MATRIX_OVERLOAD (type, use, M, N, frag_type, frag_size, \
31
- Prec) \
31
+ #define __SYCL_JOINT_MATRIX_OVERLOAD (type, use, M, N, frag_type, frag_size ) \
32
32
template <matrix_layout Layout> \
33
33
struct joint_matrix < \
34
- type, matrix_use::use, M, N, Layout, sycl::sub_group, Prec, \
34
+ type, matrix_use::use, M, N, Layout, sycl::sub_group, \
35
35
typename std::enable_if_t <Layout == matrix_layout::row_major || \
36
36
Layout == matrix_layout::col_major>> { \
37
37
frag_type data[frag_size]; \
38
38
};
39
39
40
40
// m8n8k4 double only
// Shapes taken from the M, N macro arguments: a is 8x4, b is 4x8 and the
// accumulator is 8x8; fragments are stored as double[1], double[1], double[2].
// Removed lines pass a trailing precision::standard argument; added lines use
// the six-argument form.
41
- __SYCL_JOINT_MATRIX_OVERLOAD (double , a, 8 , 4 , double , 1 , precision::standard)
42
- __SYCL_JOINT_MATRIX_OVERLOAD (double , b, 4 , 8 , double , 1 , precision::standard)
43
- __SYCL_JOINT_MATRIX_OVERLOAD (double , accumulator, 8 , 8 , double , 2 ,
44
- precision::standard)
41
+ __SYCL_JOINT_MATRIX_OVERLOAD (double , a, 8 , 4 , double , 1 )
42
+ __SYCL_JOINT_MATRIX_OVERLOAD (double , b, 4 , 8 , double , 1 )
43
+ __SYCL_JOINT_MATRIX_OVERLOAD (double , accumulator, 8 , 8 , double , 2 )
45
44
46
45
// m8n32k16
// Shapes taken from the M, N macro arguments: a is 8x16, b is 16x32 and the
// accumulator is 8x32. The last two macro arguments give the per-specialization
// fragment storage type and element count.
// Removed lines pass a trailing precision::standard argument; added lines use
// the six-argument form with otherwise identical values.
47
46
// bf16 data format uses uint16_t data type
48
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 8 , 16 , int32_t , 2 ,
49
- precision::standard)
50
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 32 , int32_t , 8 ,
51
- precision::standard)
52
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 8 , 16 , int32_t , 8 , precision::standard)
53
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 32 , int32_t , 8 , precision::standard)
54
- __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 8 , 32 , float , 8 ,
55
- precision::standard)
56
- __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 8 , 32 , int32_t , 4 ,
57
- precision::standard)
58
-
59
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 8 , 16 , int32_t , 1 , precision::standard)
60
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 32 , int32_t , 4 , precision::standard)
61
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 8 , 16 , int32_t , 1 , precision::standard)
62
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 32 , int32_t , 4 ,
63
- precision::standard)
64
- __SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 8 , 32 , int32_t , 8 ,
65
- precision::standard)
47
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 8 , 16 , int32_t , 2 )
48
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 32 , int32_t , 8 )
49
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 8 , 16 , int32_t , 8 )
50
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 32 , int32_t , 8 )
51
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 8 , 32 , float , 8 )
52
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 8 , 32 , int32_t , 4 )
53
+
54
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 8 , 16 , int32_t , 1 )
55
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 32 , int32_t , 4 )
56
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 8 , 16 , int32_t , 1 )
57
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 32 , int32_t , 4 )
58
+ __SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 8 , 32 , int32_t , 8 )
66
59
67
60
// m32n8k16
// Shapes taken from the M, N macro arguments: a is 32x16, b is 16x8 and the
// accumulator is 32x8. The last two macro arguments give the per-specialization
// fragment storage type and element count.
// Removed lines pass a trailing precision::standard argument; added lines use
// the six-argument form with otherwise identical values.
68
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 32 , 16 , int32_t , 8 ,
69
- precision::standard)
70
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 8 , int32_t , 2 ,
71
- precision::standard)
72
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 32 , 16 , int32_t , 8 , precision::standard)
73
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 8 , int32_t , 8 , precision::standard)
74
- __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 32 , 8 , float , 8 ,
75
- precision::standard)
76
- __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 32 , 8 , int32_t , 4 ,
77
- precision::standard)
78
-
79
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 32 , 16 , int32_t , 4 , precision::standard)
80
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 8 , int32_t , 1 , precision::standard)
81
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 32 , 16 , int32_t , 4 ,
82
- precision::standard)
83
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 8 , int32_t , 1 , precision::standard)
84
- __SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 32 , 8 , int32_t , 8 ,
85
- precision::standard)
61
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 32 , 16 , int32_t , 8 )
62
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 8 , int32_t , 2 )
63
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 32 , 16 , int32_t , 8 )
64
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 8 , int32_t , 8 )
65
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 32 , 8 , float , 8 )
66
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 32 , 8 , int32_t , 4 )
67
+
68
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 32 , 16 , int32_t , 4 )
69
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 8 , int32_t , 1 )
70
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 32 , 16 , int32_t , 4 )
71
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 8 , int32_t , 1 )
72
+ __SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 32 , 8 , int32_t , 8 )
86
73
87
74
// m16n16k16
// Shapes taken from the M, N macro arguments: a, b and the accumulator are all
// 16x16. The last two macro arguments give the per-specialization fragment
// storage type and element count.
// Removed lines pass a trailing precision::standard argument; added lines use
// the six-argument form with otherwise identical values.
88
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 16 , 16 , int32_t , 4 ,
89
- precision::standard)
90
- __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 16 , int32_t , 4 ,
91
- precision::standard)
92
- __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 16 , 16 , int32_t , 8 , precision::standard)
93
- __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 16 , int32_t , 8 , precision::standard)
94
- __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 16 , 16 , float , 8 ,
95
- precision::standard)
96
- __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 16 , 16 , int32_t , 4 ,
97
- precision::standard)
98
-
99
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 16 , 16 , int32_t , 2 , precision::standard)
100
- __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 16 , int32_t , 2 , precision::standard)
101
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 16 , 16 , int32_t , 2 ,
102
- precision::standard)
103
- __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 16 , int32_t , 2 ,
104
- precision::standard)
105
- __SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 16 , 16 , int32_t , 8 ,
106
- precision::standard)
75
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 16 , 16 , int32_t , 4 )
76
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 16 , int32_t , 4 )
77
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, a, 16 , 16 , int32_t , 8 )
78
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 16 , int32_t , 8 )
79
+ __SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 16 , 16 , float , 8 )
80
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 16 , 16 , int32_t , 4 )
81
+
82
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 16 , 16 , int32_t , 2 )
83
+ __SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 16 , int32_t , 2 )
84
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 16 , 16 , int32_t , 2 )
85
+ __SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 16 , int32_t , 2 )
86
+ __SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 16 , 16 , int32_t , 8 )
107
87
108
88
// m16n16k8 tf32
// Shapes taken from the M, N macro arguments: a is 16x8 and b is 8x16.
// The removed lines declared float-typed matrices with int32_t[4] fragment
// storage and selected tf32 through a trailing precision::tf32 argument.
// The added lines instead use the precision::tf32 tag as the element type and
// store the fragment as float[4].
109
- __SYCL_JOINT_MATRIX_OVERLOAD (float , a, 16 , 8 , int32_t , 4 , precision::tf32 )
110
- __SYCL_JOINT_MATRIX_OVERLOAD (float , b, 8 , 16 , int32_t , 4 , precision::tf32 )
89
+ __SYCL_JOINT_MATRIX_OVERLOAD (precision::tf32 , a, 16 , 8 , float , 4 )
90
+ __SYCL_JOINT_MATRIX_OVERLOAD (precision::tf32 , b, 8 , 16 , float , 4 )
111
91
112
92
// The helper macro is not used past this point, so it is undefined here.
#undef __SYCL_JOINT_MATRIX_OVERLOAD
113
93
} // namespace experimental::matrix
114
94
115
95
namespace detail {
116
96
117
// Primary template of the load helper. It only declares `load`; a partial
// specialization later in the file supplies the body for row_major and
// col_major layouts.
// In the added lines the matrix element type `S` is a separate parameter from
// the pointee type `T` of `src`, so `res` is a joint_matrix<S, ...> while
// `src` is a multi_ptr<T, Space>. The removed lines used a single `T` for both
// plus a `Prec` parameter defaulting to precision::standard.
// `stride` is forwarded by the specialization to the load builtins.
- template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
97
+ template <typename S, typename T,
98
+ sycl::ext::oneapi::experimental::matrix::matrix_use Use,
118
99
size_t NumRows, size_t NumCols,
119
100
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
120
- access::address_space Space,
121
- sycl::ext::oneapi::experimental::matrix::precision Prec =
122
- sycl::ext::oneapi::experimental::matrix::precision::standard,
123
- typename Cond = void >
101
+ access::address_space Space, typename Cond = void >
124
102
struct joint_matrix_load_impl {
125
103
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
126
- T , Use, NumRows, NumCols, Layout, sycl::sub_group, Prec > &res,
104
+ S , Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
127
105
multi_ptr<T, Space> src, size_t stride);
128
106
};
129
107
@@ -142,19 +120,19 @@ constexpr int get_layout_id<
142
120
return 1 ;
143
121
}
144
122
145
- template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
123
+ template <typename S, typename T,
124
+ sycl::ext::oneapi::experimental::matrix::matrix_use Use,
146
125
size_t NumRows, size_t NumCols,
147
126
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
148
- access::address_space Space,
149
- sycl::ext::oneapi::experimental::matrix::precision Prec>
127
+ access::address_space Space>
150
128
struct joint_matrix_load_impl <
151
- T, Use, NumRows, NumCols, Layout, Space, Prec ,
129
+ S, T, Use, NumRows, NumCols, Layout, Space,
152
130
typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
153
131
matrix::matrix_layout::row_major ||
154
132
Layout == sycl::ext::oneapi::experimental::
155
133
matrix::matrix_layout::col_major>> {
156
134
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
157
- T , Use, NumRows, NumCols, Layout, sycl::sub_group, Prec > &res,
135
+ S , Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
158
136
multi_ptr<T, Space> src, size_t stride) {
159
137
if constexpr (std::is_same<T, uint16_t >::value) {
160
138
int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
@@ -279,21 +257,7 @@ struct joint_matrix_load_impl<
279
257
get_layout_id<Layout>());
280
258
}
281
259
} else if constexpr (std::is_same<T, float >::value) {
282
- if constexpr (Prec ==
283
- sycl::ext::oneapi::experimental::matrix::precision::tf32) {
284
- int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
285
- if constexpr (NumRows == 16 && NumCols == 8 ) {
286
- __mma_tf32_m16n16k8_ld_a (res.data , tileptr, stride,
287
- get_layout_id<Layout>());
288
- } else if constexpr (NumRows == 8 && NumCols == 16 ) {
289
- __mma_tf32_m16n16k8_ld_b (res.data , tileptr, stride,
290
- get_layout_id<Layout>());
291
- }
292
- for (int i = 0 ; i < 4 ; ++i) {
293
- auto tmpfloat = __nvvm_bitcast_i2f (res.data [i]);
294
- res.data [i] = __nvvm_f2tf32_rna (tmpfloat);
295
- }
296
- } else {
260
+ if (std::is_same<S, float >::value) {
297
261
if constexpr (NumRows == 16 && NumCols == 16 ) {
298
262
__hmma_m16n16k16_ld_c_f32 (res.data , src.get (), stride,
299
263
get_layout_id<Layout>());
@@ -304,6 +268,16 @@ struct joint_matrix_load_impl<
304
268
__hmma_m32n8k16_ld_c_f32 (res.data , src.get (), stride,
305
269
get_layout_id<Layout>());
306
270
}
271
+ } else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
272
+ precision::tf32>::value) {
273
+ int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
274
+ if constexpr (NumRows == 16 && NumCols == 8 ) {
275
+ __mma_tf32_m16n16k8_ld_a (reinterpret_cast <int32_t *>(res.data ),
276
+ tileptr, stride, get_layout_id<Layout>());
277
+ } else if constexpr (NumRows == 8 && NumCols == 16 ) {
278
+ __mma_tf32_m16n16k8_ld_b (reinterpret_cast <int32_t *>(res.data ),
279
+ tileptr, stride, get_layout_id<Layout>());
280
+ }
307
281
}
308
282
} else if constexpr (std::is_same<T, double >::value) {
309
283
if constexpr (Use ==
@@ -395,19 +369,18 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
395
369
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
396
370
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
397
371
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
398
- sycl::ext::oneapi::experimental::matrix::precision Prec,
399
372
typename Cond = void >
400
373
struct joint_matrix_mad_impl {
401
374
sycl::ext::oneapi::experimental::matrix::joint_matrix<
402
375
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
403
376
N, LayoutC, sycl::sub_group>
404
377
mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
405
378
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
406
- LayoutA, sycl::sub_group, Prec >
379
+ LayoutA, sycl::sub_group>
407
380
A,
408
381
sycl::ext::oneapi::experimental::matrix::joint_matrix<
409
382
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
410
- LayoutB, sycl::sub_group, Prec >
383
+ LayoutB, sycl::sub_group>
411
384
B,
412
385
sycl::ext::oneapi::experimental::matrix::joint_matrix<
413
386
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -450,10 +423,9 @@ constexpr int get_layout_pair_id<
450
423
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
451
424
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
452
425
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
453
- sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
454
- sycl::ext::oneapi::experimental::matrix::precision Prec>
426
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
455
427
struct joint_matrix_mad_impl <
456
- T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec,
428
+ T1, T2, M, K, N, LayoutA, LayoutB, LayoutC,
457
429
typename std::enable_if_t <
458
430
(LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
459
431
row_major ||
@@ -472,11 +444,11 @@ struct joint_matrix_mad_impl<
472
444
N, LayoutC, sycl::sub_group>
473
445
mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
474
446
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
475
- LayoutA, sycl::sub_group, Prec >
447
+ LayoutA, sycl::sub_group>
476
448
A,
477
449
sycl::ext::oneapi::experimental::matrix::joint_matrix<
478
450
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
479
- LayoutB, sycl::sub_group, Prec >
451
+ LayoutB, sycl::sub_group>
480
452
B,
481
453
sycl::ext::oneapi::experimental::matrix::joint_matrix<
482
454
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
@@ -545,10 +517,9 @@ struct joint_matrix_mad_impl<
545
517
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
546
518
}
547
519
}
548
- } else if constexpr (M == 16 && N == 16 && K == 8 &&
549
- Prec == sycl::ext::oneapi::experimental::matrix::
550
- precision::tf32) {
551
- __mma_tf32_m16n16k8_mma_f32 (D.data , A.data , B.data , C.data ,
520
+ } else if constexpr (M == 16 && N == 16 && K == 8 ) {
521
+ __mma_tf32_m16n16k8_mma_f32 (D.data , reinterpret_cast <int32_t *>(A.data ),
522
+ reinterpret_cast <int32_t *>(B.data ), C.data ,
552
523
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
553
524
} else if constexpr (std::is_same<T1, double >::value) {
554
525
__dmma_m8n8k4_mma_f64 (D.data , A.data , B.data , C.data ,
@@ -562,15 +533,19 @@ struct joint_matrix_mad_impl<
562
533
563
534
namespace experimental ::matrix {
564
535
565
- template <typename Group, typename T, matrix_use Use, size_t NumRows,
566
- size_t NumCols, matrix_layout Layout, access::address_space Space,
567
- precision Prec = precision::standard>
536
+ template <typename Group, typename S, typename T, matrix_use Use,
537
+ size_t NumRows, size_t NumCols, matrix_layout Layout,
538
+ access::address_space Space,
539
+ std::enable_if_t <std::is_same<S, T>::value ||
540
+ (std::is_same<S, precision::tf32>::value &&
541
+ std::is_same<T, float >::value),
542
+ bool > = true >
568
543
void joint_matrix_load (
569
- Group sg, joint_matrix<T , Use, NumRows, NumCols, Layout, Group, Prec > &res,
544
+ Group sg, joint_matrix<S , Use, NumRows, NumCols, Layout, Group> &res,
570
545
multi_ptr<T, Space> src, size_t stride) {
571
546
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
572
- sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
573
- Layout, Space, Prec >{}
547
+ sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
548
+ Layout, Space>{}
574
549
.load (res, src, stride);
575
550
#else
576
551
(void )sg;
@@ -608,15 +583,15 @@ void joint_matrix_store(Group sg,
608
583
609
584
template <typename Group, typename T1, typename T2, std::size_t M,
610
585
std::size_t K, std::size_t N, matrix_layout LayoutA,
611
- matrix_layout LayoutB, matrix_layout LayoutC, precision Prec >
586
+ matrix_layout LayoutB, matrix_layout LayoutC>
612
587
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
613
588
joint_matrix_mad (
614
- Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group, Prec > A,
615
- joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group, Prec > B,
589
+ Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
590
+ joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
616
591
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
617
592
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
618
593
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
619
- T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, Prec >{}
594
+ T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
620
595
.mad (A, B, C);
621
596
#else
622
597
(void )sg;
@@ -629,6 +604,24 @@ joint_matrix_mad(
629
604
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
630
605
}
631
606
607
+ float float_to_tf32 (float a) {
608
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
609
+ int32_t tmp_int = __nvvm_f2tf32_rna (a);
610
+ return __nvvm_bitcast_i2f (tmp_int);
611
+ #else
612
+ throw runtime_error (" When using SYCL_EXT_ONEAPI_MATRIX=3 float_to_tf32 is "
613
+ " only supported by CUDA devices" ,
614
+ PI_INVALID_DEVICE);
615
+ #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
616
+ }
617
+
618
/// Zeros out the bottom 13 bits of the 32-bit pattern of `a` and returns the
/// result as a float. The sign, the exponent and the top 10 mantissa bits are
/// left untouched; no rounding is performed.
///
/// Declared `inline` because this is a namespace-scope function definition in
/// what appears to be a header: without `inline`, including the file from two
/// translation units produces multiple definitions of the same function.
///
/// \param a value whose low 13 bits are to be cleared.
/// \return `a` with the low 13 bits of its representation set to zero.
inline float tf32_to_float(float a) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "tf32_to_float requires a 32-bit float");
  // Copy the bytes instead of casting references between float and uint32_t:
  // reading a float object through a uint32_t lvalue (and back) violates the
  // strict-aliasing rules. The compiler builtin is used so that no additional
  // include is needed; it compiles down to a plain register move.
  uint32_t bits;
  __builtin_memcpy(&bits, &a, sizeof(bits));
  bits &= 0xFFFFE000u;
  float result;
  __builtin_memcpy(&result, &bits, sizeof(result));
  return result;
}
624
+
632
625
} // namespace experimental::matrix
633
626
} // namespace oneapi
634
627
} // namespace ext
0 commit comments