Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 4c2dc33

Browse files
[SYCL] Fix CUDA tests using bfloat16 (#1421)
* [SYCL] Fix CUDA tests using bfloat16 * Add missing using in element_wise_wi_marray_legacy Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 7d6694e commit 4c2dc33

File tree

4 files changed

+3
-5
lines changed

4 files changed

+3
-5
lines changed

SYCL/BFloat16/bfloat16_builtins.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
using namespace sycl;
1515
using namespace sycl::ext::oneapi;
16+
using namespace sycl::ext::oneapi::experimental;
1617

1718
constexpr int N = 60; // divisible by all tested array sizes
1819
constexpr float bf16_eps = 0.00390625;

SYCL/Matrix/element_wise_all_ops_bf16.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
using namespace sycl;
1919
using namespace sycl::ext::intel;
20+
using namespace sycl::ext::oneapi;
2021
using namespace sycl::ext::oneapi::experimental::matrix;
2122

2223
#define SG_SZ 16

SYCL/Matrix/element_wise_wi_marray_legacy.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <sycl/sycl.hpp>
1414

1515
using namespace sycl;
16+
using namespace sycl::ext::oneapi;
1617
using namespace sycl::ext::oneapi::experimental;
1718
using namespace sycl::ext::oneapi::experimental::matrix;
1819

SYCL/Matrix/joint_matrix_tensorcores_legacy.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,8 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
6666
if constexpr (std::is_same<T1, uint16_t>::value) {
6767
for (int k = 0; k < Big_K; k++)
6868
res += make_fp32(A[m * Big_K + k]) * make_fp32(B[k * Big_N + n]);
69-
} else if constexpr (std::is_same<T1, bfloat16>::value) {
70-
for (int k = 0; k < Big_K; k++)
71-
res +=
72-
make_fp32(A[m * Big_K + k].raw()) * make_fp32(B[k * Big_N + n].raw());
7369
} else {
7470
for (int k = 0; k < Big_K; k++)
75-
7671
res +=
7772
static_cast<T2>(A[m * Big_K + k]) * static_cast<T2>(B[k * Big_N + n]);
7873
}

0 commit comments

Comments
 (0)