Skip to content

Commit 3982001

Browse files
committed
Use the PTX 7.0 neg builtin for bfloat16 unary minus
1 parent 4d99f3f commit 3982001

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ BUILTIN(__nvvm_fabs_ftz_f, "ff", "")
182182
BUILTIN(__nvvm_fabs_f, "ff", "")
183183
BUILTIN(__nvvm_fabs_d, "dd", "")
184184

185+
// Neg (bf16, bf16x2)
186+
187+
TARGET_BUILTIN(__nvvm_neg_bf16, "ZUsZUs", "", AND(SM_80,PTX70))
188+
TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
189+
185190
// Round
186191

187192
BUILTIN(__nvvm_round_ftz_f, "ff", "")

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,19 @@ let TargetPrefix = "nvvm" in {
740740
def int_nvvm_fabs_d : GCCBuiltin<"__nvvm_fabs_d">,
741741
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem, IntrSpeculatable]>;
742742

743+
//
744+
// Neg bf16, bf16x2
745+
//
746+
747+
foreach unary = ["neg"] in {
748+
def int_nvvm_ # unary # _bf16 :
749+
GCCBuiltin<!strconcat("__nvvm_", unary, "_bf16")>,
750+
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem]>;
751+
def int_nvvm_ # unary # _bf16x2 :
752+
GCCBuiltin<!strconcat("__nvvm_", unary, "_bf16x2")>,
753+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>;
754+
}
755+
743756
//
744757
// Round
745758
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,15 @@ def INT_NVVM_FABS_F : F_MATH_1<"abs.f32 \t$dst, $src0;", Float32Regs,
719719
def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs,
720720
Float64Regs, int_nvvm_fabs_d>;
721721

722+
//
723+
// Neg bf16, bf16x2
724+
//
725+
726+
def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", Int16Regs,
727+
Int16Regs, int_nvvm_neg_bf16>;
728+
def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs,
729+
Int32Regs, int_nvvm_neg_bf16x2>;
730+
722731
//
723732
// Round
724733
//

sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
4242
static float to_float(const storage_t &a) {
4343
#if defined(__SYCL_DEVICE_ONLY__)
4444
#if defined(__NVPTX__)
45-
unsigned int y = a;
45+
uint32_t y = a;
4646
y = y << 16;
4747
float *res = reinterpret_cast<float *>(&y);
4848
return *res;
@@ -81,7 +81,16 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
8181

8282
// Unary minus operator overloading (device only; throws
// errc::feature_not_supported when compiled for the host)
8383
friend bfloat16 operator-(bfloat16 &lhs) {
84-
return bfloat16{-to_float(lhs.value)};
84+
#if defined(__SYCL_DEVICE_ONLY__)
85+
#if defined(__NVPTX__)
86+
return from_bits(__nvvm_neg_bf16(lhs.value));
87+
#else
88+
return bfloat16{-__spirv_ConvertBF16ToFINTEL(lhs.value)};
89+
#endif
90+
#else
91+
throw exception{errc::feature_not_supported,
92+
"Bfloat16 unary minus is not supported on host device"};
93+
#endif
8594
}
8695

8796
// Increment and decrement operators overloading

0 commit comments

Comments
 (0)