7
7
// ===--------------------------------------------------------------------=== //
8
8
9
9
#pragma once
10
- #include < sycl/ext/oneapi/experimental/ bfloat16.hpp>
10
+ #include < sycl/ext/oneapi/bfloat16.hpp>
11
11
12
12
__SYCL_INLINE_NAMESPACE (cl) {
13
13
namespace sycl {
@@ -219,8 +219,7 @@ struct joint_matrix_load_impl<
219
219
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
220
220
multi_ptr<T, Space> src, size_t stride) {
221
221
if constexpr (std::is_same<T, uint16_t >::value ||
222
- std::is_same<
223
- T, sycl::ext::oneapi::experimental::bfloat16>::value) {
222
+ std::is_same<T, sycl::ext::oneapi::bfloat16>::value) {
224
223
auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
225
224
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
226
225
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -585,8 +584,8 @@ struct joint_matrix_mad_impl<
585
584
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
586
585
}
587
586
} else if constexpr (std::is_same<T1, uint16_t >::value ||
588
- std::is_same<T1, sycl::ext::oneapi::experimental::
589
- bfloat16>::value) {
587
+ std::is_same<T1,
588
+ sycl::ext::oneapi:: bfloat16>::value) {
590
589
__mma_bf16_m16n16k16_mma_f32 (
591
590
reinterpret_cast <float *>(&D.wi_marray ),
592
591
reinterpret_cast <int32_t const *>(&A.wi_marray ),
@@ -622,8 +621,8 @@ struct joint_matrix_mad_impl<
622
621
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
623
622
}
624
623
} else if constexpr (std::is_same<T1, uint16_t >::value ||
625
- std::is_same<T1, sycl::ext::oneapi::experimental::
626
- bfloat16>::value) {
624
+ std::is_same<T1,
625
+ sycl::ext::oneapi:: bfloat16>::value) {
627
626
__mma_bf16_m8n32k16_mma_f32 (
628
627
reinterpret_cast <float *>(&D.wi_marray ),
629
628
reinterpret_cast <int32_t const *>(&A.wi_marray ),
@@ -645,8 +644,8 @@ struct joint_matrix_mad_impl<
645
644
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
646
645
}
647
646
} else if constexpr (std::is_same<T1, uint16_t >::value ||
648
- std::is_same<T1, sycl::ext::oneapi::experimental::
649
- bfloat16>::value) {
647
+ std::is_same<T1,
648
+ sycl::ext::oneapi:: bfloat16>::value) {
650
649
__mma_bf16_m32n8k16_mma_f32 (
651
650
reinterpret_cast <float *>(&D.wi_marray ),
652
651
reinterpret_cast <int32_t const *>(&A.wi_marray ),
0 commit comments