Skip to content

Commit 7908298

Browse files
authored
[SYCL][ESIMD] Fix the issue with converting half scalar into a simd vector (#12121)
1 parent 86734c0 commit 7908298

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

sycl/include/sycl/ext/intel/esimd/detail/half_type_traits.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ template <>
121121
struct is_esimd_arithmetic_type<half_raw_type, void> : std::true_type {};
122122
#endif // __SYCL_DEVICE_ONLY__
123123

124+
template <>
125+
struct is_esimd_arithmetic_type<sycl::half, void> : std::true_type {};
126+
124127
// Misc
125128
inline std::ostream &operator<<(std::ostream &O, sycl::half const &rhs) {
126129
O << static_cast<float>(rhs);
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
//= bfloat16_half_vector_plus_eq_scalar.cpp - Test for bfloat16 operators =//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
#include "../esimd_test_utils.hpp"
11+
#include <iostream>
12+
#include <sycl/ext/intel/esimd.hpp>
13+
#include <sycl/sycl.hpp>
14+
15+
using namespace sycl;
16+
using namespace sycl::ext::intel::esimd;
17+
using namespace sycl::ext::intel::experimental::esimd;
18+
19+
template <typename T> ESIMD_NOINLINE bool test(queue Q) {
20+
std::cout << "Testing T=" << esimd_test::type_name<T>() << "...\n";
21+
22+
constexpr int N = 8;
23+
24+
constexpr int NumOps = 4;
25+
constexpr int CSize = NumOps * N;
26+
27+
T *Mem = malloc_shared<T>(CSize, Q);
28+
T TOne = static_cast<T>(1);
29+
T TTen = static_cast<T>(10);
30+
31+
Q.single_task([=]() SYCL_ESIMD_KERNEL {
32+
{
33+
simd<T, N> Vec(TOne);
34+
Vec += TTen;
35+
Vec.copy_to(Mem);
36+
}
37+
{
38+
simd<T, N> Vec(TOne);
39+
Vec -= TTen;
40+
Vec.copy_to(Mem + N);
41+
}
42+
{
43+
simd<T, N> Vec(TOne);
44+
Vec *= TTen;
45+
Vec.copy_to(Mem + 2 * N);
46+
}
47+
{
48+
simd<T, N> Vec(TOne);
49+
Vec /= TTen;
50+
Vec.copy_to(Mem + 3 * N);
51+
}
52+
}).wait();
53+
54+
bool ReturnValue = true;
55+
for (int i = 0; i < N; ++i) {
56+
if (Mem[i] != TOne + TTen) {
57+
ReturnValue = false;
58+
break;
59+
}
60+
if (Mem[i + N] != TOne - TTen) {
61+
ReturnValue = false;
62+
break;
63+
}
64+
if (Mem[i + 2 * N] != TOne * TTen) {
65+
ReturnValue = false;
66+
break;
67+
}
68+
if (!((Mem[i + 3 * N] == (TOne / TTen)) ||
69+
(std::abs((double)(Mem[i + 3 * N] - (TOne / TTen)) /
70+
(double)(TOne / TTen)) <= 0.001))) {
71+
ReturnValue = false;
72+
break;
73+
}
74+
}
75+
76+
free(Mem, Q);
77+
return ReturnValue;
78+
}
79+
80+
int main() {
81+
queue Q;
82+
esimd_test::printTestLabel(Q);
83+
84+
bool SupportsHalf = Q.get_device().has(aspect::fp16);
85+
86+
bool Passed = true;
87+
Passed &= test<int>(Q);
88+
Passed &= test<float>(Q);
89+
if (SupportsHalf) {
90+
Passed &= test<sycl::half>(Q);
91+
}
92+
93+
#ifdef USE_BF16
94+
// TODO: Reenable once the issue with bfloat16 is resolved
95+
// Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
96+
#endif
97+
#ifdef USE_TF32
98+
Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
99+
#endif
100+
std::cout << (Passed ? "Passed\n" : "FAILED\n");
101+
return Passed ? 0 : 1;
102+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//= bfloat16_half_vector_plus_eq_scalar_pvc.cpp - Test for bfloat16 operators=//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: gpu-intel-pvc
9+
// RUN: %{build} -o %t.out
10+
// RUN: %{run} %t.out
11+
12+
#define USE_BF16
13+
#define USE_TF32
14+
#include "bfloat16_half_vector_plus_eq_scalar.cpp"

0 commit comments

Comments
 (0)