Skip to content

Commit bdd88e5

Browse files
committed
Corrections to tests.
1 parent 6014cef commit bdd88e5

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// ===--------------------------------------------------------------------=== //
88

99
#pragma once
10-
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
10+
#include <sycl/ext/oneapi/bfloat16.hpp>
1111

1212
__SYCL_INLINE_NAMESPACE(cl) {
1313
namespace sycl {
@@ -219,8 +219,7 @@ struct joint_matrix_load_impl<
219219
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
220220
multi_ptr<T, Space> src, size_t stride) {
221221
if constexpr (std::is_same<T, uint16_t>::value ||
222-
std::is_same<
223-
T, sycl::ext::oneapi::experimental::bfloat16>::value) {
222+
std::is_same<T, sycl::ext::oneapi::bfloat16>::value) {
224223
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
225224
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
226225
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -585,8 +584,8 @@ struct joint_matrix_mad_impl<
585584
get_layout_pair_id<LayoutA, LayoutB>(), 0);
586585
}
587586
} else if constexpr (std::is_same<T1, uint16_t>::value ||
588-
std::is_same<T1, sycl::ext::oneapi::experimental::
589-
bfloat16>::value) {
587+
std::is_same<T1,
588+
sycl::ext::oneapi::bfloat16>::value) {
590589
__mma_bf16_m16n16k16_mma_f32(
591590
reinterpret_cast<float *>(&D.wi_marray),
592591
reinterpret_cast<int32_t const *>(&A.wi_marray),
@@ -622,8 +621,8 @@ struct joint_matrix_mad_impl<
622621
get_layout_pair_id<LayoutA, LayoutB>(), 0);
623622
}
624623
} else if constexpr (std::is_same<T1, uint16_t>::value ||
625-
std::is_same<T1, sycl::ext::oneapi::experimental::
626-
bfloat16>::value) {
624+
std::is_same<T1,
625+
sycl::ext::oneapi::bfloat16>::value) {
627626
__mma_bf16_m8n32k16_mma_f32(
628627
reinterpret_cast<float *>(&D.wi_marray),
629628
reinterpret_cast<int32_t const *>(&A.wi_marray),
@@ -645,8 +644,8 @@ struct joint_matrix_mad_impl<
645644
get_layout_pair_id<LayoutA, LayoutB>(), 0);
646645
}
647646
} else if constexpr (std::is_same<T1, uint16_t>::value ||
648-
std::is_same<T1, sycl::ext::oneapi::experimental::
649-
bfloat16>::value) {
647+
std::is_same<T1,
648+
sycl::ext::oneapi::bfloat16>::value) {
650649
__mma_bf16_m32n8k16_mma_f32(
651650
reinterpret_cast<float *>(&D.wi_marray),
652651
reinterpret_cast<int32_t const *>(&A.wi_marray),

sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
using namespace sycl;
88
using namespace sycl::ext::oneapi::experimental::matrix;
9-
using sycl::ext::oneapi::experimental::bfloat16;
9+
using sycl::ext::oneapi::bfloat16;
1010

1111
constexpr int stride = 16;
1212

sycl/test/matrix/matrix-bfloat16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <iostream>
55

66
using namespace sycl::ext::oneapi::experimental::matrix;
7-
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
7+
using bfloat16 = sycl::ext::oneapi::bfloat16;
88

99
static constexpr auto TILE_SZ = 16;
1010
static constexpr auto TM = TILE_SZ - 1;

0 commit comments

Comments
 (0)