Skip to content

Commit d0ebfc3

Browse files
authored
[SYCL][CUDA] Swapped bf16 builtins with inline asm. (#10695)
A recent pulldown had to revert the upstream commit llvm/llvm-project@250f2bb. The reason is that the SYCL runtime uses `uint16_t` as the storage type for `bfloat16`, whereas that commit uses `__bf16` instead. Eventually we plan to fully support the `__bf16` type in the SYCL runtime; see the discussion in #10457. Until `__bf16` support is added, it is easy for us to instead use inline asm operating on `uint16_t` in the nvptx backend. This PR makes that change. Once this PR is merged, llvm/llvm-project@250f2bb can be safely pulled down. --------- Signed-off-by: JackAKirk <[email protected]>
1 parent c662a34 commit d0ebfc3

File tree

6 files changed

+40
-12
lines changed

6 files changed

+40
-12
lines changed

libclc/ptx-nvidiacl/libspirv/math/fabs.cl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@
1818

1919
// Requires at least sm_80
2020
_CLC_DEF _CLC_OVERLOAD ushort __clc_fabs(ushort x) {
21-
return __nvvm_abs_bf16(x);
21+
ushort res;
22+
__asm__("abs.bf16 %0, %1;" : "=h"(res) : "h"(x));
23+
return res;
2224
}
2325
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fabs, ushort)
2426

2527
// Requires at least sm_80
2628
_CLC_DEF _CLC_OVERLOAD uint __clc_fabs(uint x) {
27-
return __nvvm_abs_bf16x2(x);
29+
uint res;
30+
__asm__("abs.bf16x2 %0, %1;" : "=r"(res) : "r"(x));
31+
return res;
2832
}
2933
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fabs, uint)

libclc/ptx-nvidiacl/libspirv/math/fma.cl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,18 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fma,
5050

5151
// Requires at least sm_80
5252
_CLC_DEF _CLC_OVERLOAD ushort __clc_fma(ushort x, ushort y, ushort z) {
53-
return __nvvm_fma_rn_bf16(x, y, z);
53+
ushort res;
54+
__asm__("fma.rn.bf16 %0, %1, %2, %3;" : "=h"(res) : "h"(x), "h"(y), "h"(z));
55+
return res;
5456
}
5557
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma, ushort,
5658
ushort, ushort)
5759

5860
// Requires at least sm_80
5961
_CLC_DEF _CLC_OVERLOAD uint __clc_fma(uint x, uint y, uint z) {
60-
return __nvvm_fma_rn_bf16x2(x, y, z);
62+
uint res;
63+
__asm__("fma.rn.bf16x2 %0, %1, %2, %3;" : "=r"(res) : "r"(x), "r"(y), "r"(z));
64+
return res;
6165
}
6266
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fma, uint,
6367
uint, uint)

libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __clc_fma_relu,
4040
_CLC_DEF _CLC_OVERLOAD ushort __clc_fma_relu(ushort x, ushort y,
4141
ushort z) {
4242
if (__clc_nvvm_reflect_arch() >= 800) {
43-
return __nvvm_fma_rn_relu_bf16(x, y, z);
43+
ushort res;
44+
__asm__("fma.rn.relu.bf16 %0, %1, %2, %3;"
45+
: "=h"(res)
46+
: "h"(x), "h"(y), "h"(z));
47+
return res;
4448
}
4549
__builtin_trap();
4650
__builtin_unreachable();
@@ -50,7 +54,11 @@ _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma_relu,
5054

5155
_CLC_DEF _CLC_OVERLOAD uint __clc_fma_relu(uint x, uint y, uint z) {
5256
if (__clc_nvvm_reflect_arch() >= 800) {
53-
return __nvvm_fma_rn_relu_bf16x2(x, y, z);
57+
uint res;
58+
__asm__("fma.rn.relu.bf16x2 %0, %1, %2, %3;"
59+
: "=r"(res)
60+
: "r"(x), "r"(y), "r"(z));
61+
return res;
5462
}
5563
__builtin_trap();
5664
__builtin_unreachable();

libclc/ptx-nvidiacl/libspirv/math/fmax.cl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,18 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmax,
5353

5454
// Requires at least sm_80
5555
_CLC_DEF _CLC_OVERLOAD ushort __clc_fmax(ushort x, ushort y) {
56-
return __nvvm_fmax_bf16(x, y);
56+
ushort res;
57+
__asm__("max.bf16 %0, %1, %2;" : "=h"(res) : "h"(x), "h"(y));
58+
return res;
5759
}
5860
_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmax, ushort,
5961
ushort)
6062

6163
// Requires at least sm_80
6264
_CLC_DEF _CLC_OVERLOAD uint __clc_fmax(uint x, uint y) {
63-
return __nvvm_fmax_bf16x2(x, y);
65+
uint res;
66+
__asm__("max.bf16x2 %0, %1, %2;" : "=r"(res) : "r"(x), "r"(y));
67+
return res;
6468
}
6569
_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmax, uint,
6670
uint)

libclc/ptx-nvidiacl/libspirv/math/fmin.cl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,18 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmin, half
5353

5454
// Requires at least sm_80
5555
_CLC_DEF _CLC_OVERLOAD ushort __clc_fmin(ushort x, ushort y) {
56-
return __nvvm_fmin_bf16(x, y);
56+
ushort res;
57+
__asm__("min.bf16 %0, %1, %2;" : "=h"(res) : "h"(x), "h"(y));
58+
return res;
5759
}
5860
_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmin, ushort,
5961
ushort)
6062

6163
// Requires at least sm_80
6264
_CLC_DEF _CLC_OVERLOAD uint __clc_fmin(uint x, uint y) {
63-
return __nvvm_fmin_bf16x2(x, y);
65+
uint res;
66+
__asm__("min.bf16x2 %0, %1, %2;" : "=r"(res) : "r"(x), "r"(y));
67+
return res;
6468
}
6569
_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmin, uint,
6670
uint)

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ class bfloat16 {
6464
#if defined(__SYCL_DEVICE_ONLY__)
6565
#if defined(__NVPTX__)
6666
#if (__SYCL_CUDA_ARCH__ >= 800)
67-
return __nvvm_f2bf16_rn(a);
67+
detail::Bfloat16StorageT res;
68+
asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a));
69+
return res;
6870
#else
6971
return from_float_fallback(a);
7072
#endif
@@ -120,7 +122,9 @@ class bfloat16 {
120122
friend bfloat16 operator-(bfloat16 &lhs) {
121123
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
122124
(__SYCL_CUDA_ARCH__ >= 800)
123-
return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value));
125+
detail::Bfloat16StorageT res;
126+
asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
127+
return detail::bitsToBfloat16(res);
124128
#elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
125129
return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};
126130
#else

0 commit comments

Comments
 (0)