Skip to content

Commit 7daa501

Browse files
authored
[NVPTX] Cleanup and document nvvm.fabs intrinsics, adding f16 support (#135644)
This change unifies the NVVM intrinsics for floating point absolute value into two new overloaded intrinsics "llvm.nvvm.fabs.*" and "llvm.nvvm.fabs.ftz.*". Documentation has been added specifying the semantics of these intrinsics to clarify how they differ from "llvm.fabs.*". In addition, support for these new intrinsics is extended to cover the f16 variants.
1 parent 427a779 commit 7daa501

File tree

13 files changed

+368
-118
lines changed

13 files changed

+368
-118
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,11 @@ def __nvvm_fabs_ftz_f : NVPTXBuiltin<"float(float)">;
321321
def __nvvm_fabs_f : NVPTXBuiltin<"float(float)">;
322322
def __nvvm_fabs_d : NVPTXBuiltin<"double(double)">;
323323

324+
def __nvvm_fabs_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16)", SM_53, PTX65>;
325+
def __nvvm_fabs_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>)", SM_53, PTX65>;
326+
def __nvvm_fabs_ftz_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16)", SM_53, PTX65>;
327+
def __nvvm_fabs_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>)", SM_53, PTX65>;
328+
324329
// Round
325330

326331
def __nvvm_round_ftz_f : NVPTXBuiltin<"float(float)">;

clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,21 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
10341034
case NVPTX::BI__nvvm_fmin_xorsign_abs_f16x2:
10351035
return MakeHalfType(Intrinsic::nvvm_fmin_xorsign_abs_f16x2, BuiltinID, E,
10361036
*this);
1037+
case NVPTX::BI__nvvm_fabs_f:
1038+
case NVPTX::BI__nvvm_abs_bf16:
1039+
case NVPTX::BI__nvvm_abs_bf16x2:
1040+
case NVPTX::BI__nvvm_fabs_f16:
1041+
case NVPTX::BI__nvvm_fabs_f16x2:
1042+
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_fabs,
1043+
EmitScalarExpr(E->getArg(0)));
1044+
case NVPTX::BI__nvvm_fabs_ftz_f:
1045+
case NVPTX::BI__nvvm_fabs_ftz_f16:
1046+
case NVPTX::BI__nvvm_fabs_ftz_f16x2:
1047+
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_fabs_ftz,
1048+
EmitScalarExpr(E->getArg(0)));
1049+
case NVPTX::BI__nvvm_fabs_d:
1050+
return Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
1051+
EmitScalarExpr(E->getArg(0)));
10371052
case NVPTX::BI__nvvm_ldg_h:
10381053
case NVPTX::BI__nvvm_ldg_h2:
10391054
return MakeHalfType(Intrinsic::not_intrinsic, BuiltinID, E, *this);

clang/test/CodeGen/builtins-nvptx-native-half-type.c

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 %s
2727

2828
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx-unknown-unknown -target-cpu \
29-
// RUN: sm_53 -target-feature +ptx42 -fcuda-is-device -fnative-half-type \
29+
// RUN: sm_53 -target-feature +ptx65 -fcuda-is-device -fnative-half-type \
3030
// RUN: -emit-llvm -o - -x cuda %s \
31-
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX42_SM53 %s
31+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX65_SM53 %s
3232

3333
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown \
34-
// RUN: -target-cpu sm_53 -target-feature +ptx42 -fcuda-is-device \
34+
// RUN: -target-cpu sm_53 -target-feature +ptx65 -fcuda-is-device \
3535
// RUN: -fnative-half-type -emit-llvm -o - -x cuda %s \
36-
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX42_SM53 %s
36+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX65_SM53 %s
3737

3838
#define __device__ __attribute__((device))
3939

@@ -108,25 +108,25 @@ __device__ void nvvm_fma_f16_f16x2_sm80() {
108108
// CHECK-LABEL: nvvm_fma_f16_f16x2_sm53
109109
__device__ void nvvm_fma_f16_f16x2_sm53() {
110110
#if __CUDA_ARCH__ >= 530
111-
// CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.f16
111+
// CHECK_PTX65_SM53: call half @llvm.nvvm.fma.rn.f16
112112
__nvvm_fma_rn_f16(0.1f16, 0.1f16, 0.1f16);
113-
// CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.ftz.f16
113+
// CHECK_PTX65_SM53: call half @llvm.nvvm.fma.rn.ftz.f16
114114
__nvvm_fma_rn_ftz_f16(0.1f16, 0.1f16, 0.1f16);
115-
// CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.sat.f16
115+
// CHECK_PTX65_SM53: call half @llvm.nvvm.fma.rn.sat.f16
116116
__nvvm_fma_rn_sat_f16(0.1f16, 0.1f16, 0.1f16);
117-
// CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.ftz.sat.f16
117+
// CHECK_PTX65_SM53: call half @llvm.nvvm.fma.rn.ftz.sat.f16
118118
__nvvm_fma_rn_ftz_sat_f16(0.1f16, 0.1f16, 0.1f16);
119119

120-
// CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.f16x2
120+
// CHECK_PTX65_SM53: call <2 x half> @llvm.nvvm.fma.rn.f16x2
121121
__nvvm_fma_rn_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
122122
{0.1f16, 0.7f16});
123-
// CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2
123+
// CHECK_PTX65_SM53: call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2
124124
__nvvm_fma_rn_ftz_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
125125
{0.1f16, 0.7f16});
126-
// CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2
126+
// CHECK_PTX65_SM53: call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2
127127
__nvvm_fma_rn_sat_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
128128
{0.1f16, 0.7f16});
129-
// CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2
129+
// CHECK_PTX65_SM53: call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2
130130
__nvvm_fma_rn_ftz_sat_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
131131
{0.1f16, 0.7f16});
132132
#endif
@@ -173,6 +173,23 @@ __device__ void nvvm_min_max_sm86() {
173173
// CHECK: ret void
174174
}
175175

176+
// CHECK-LABEL: nvvm_fabs_f16
177+
__device__ void nvvm_fabs_f16() {
178+
#if __CUDA_ARCH__ >= 530
179+
// CHECK: call half @llvm.nvvm.fabs.f16
180+
__nvvm_fabs_f16(0.1f16);
181+
// CHECK: call half @llvm.nvvm.fabs.ftz.f16
182+
__nvvm_fabs_ftz_f16(0.1f16);
183+
// CHECK: call <2 x half> @llvm.nvvm.fabs.v2f16
184+
__nvvm_fabs_f16x2({0.1f16, 0.7f16});
185+
// CHECK: call <2 x half> @llvm.nvvm.fabs.ftz.v2f16
186+
__nvvm_fabs_ftz_f16x2({0.1f16, 0.7f16});
187+
#endif
188+
// CHECK: ret void
189+
}
190+
191+
192+
176193
typedef __fp16 __fp16v2 __attribute__((ext_vector_type(2)));
177194

178195
// CHECK-LABEL: nvvm_ldg_native_half_types

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,14 @@ __device__ void nvvm_math(float f1, float f2, double d1, double d2) {
245245
// CHECK: call double @llvm.nvvm.rcp.rn.d
246246
double td4 = __nvvm_rcp_rn_d(d2);
247247

248+
// CHECK: call float @llvm.nvvm.fabs.f32
249+
float t6 = __nvvm_fabs_f(f1);
250+
// CHECK: call float @llvm.nvvm.fabs.ftz.f32
251+
float t7 = __nvvm_fabs_ftz_f(f2);
252+
253+
// CHECK: call double @llvm.fabs.f64
254+
double td5 = __nvvm_fabs_d(d1);
255+
248256
// CHECK: call void @llvm.nvvm.membar.cta()
249257
__nvvm_membar_cta();
250258
// CHECK: call void @llvm.nvvm.membar.gl()
@@ -1181,9 +1189,9 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
11811189
__device__ void nvvm_abs_neg_bf16_bf16x2_sm80() {
11821190
#if __CUDA_ARCH__ >= 800
11831191

1184-
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD)
1192+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fabs.bf16(bfloat 0xR3DCD)
11851193
__nvvm_abs_bf16(BF16);
1186-
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> splat (bfloat 0xR3DCD))
1194+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fabs.v2bf16(<2 x bfloat> splat (bfloat 0xR3DCD))
11871195
__nvvm_abs_bf16x2(BF16X2);
11881196

11891197
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 0xR3DCD)

llvm/docs/NVPTXUsage.rst

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,59 @@ space casted to this space), 1 is returned, otherwise 0 is returned.
309309
Arithmetic Intrinsics
310310
---------------------
311311

312+
'``llvm.nvvm.fabs.*``' Intrinsic
313+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
314+
315+
Syntax:
316+
"""""""
317+
318+
.. code-block:: llvm
319+
320+
declare float @llvm.nvvm.fabs.f32(float %a)
321+
declare double @llvm.nvvm.fabs.f64(double %a)
322+
declare half @llvm.nvvm.fabs.f16(half %a)
323+
declare <2 x half> @llvm.nvvm.fabs.v2f16(<2 x half> %a)
324+
declare bfloat @llvm.nvvm.fabs.bf16(bfloat %a)
325+
declare <2 x bfloat> @llvm.nvvm.fabs.v2bf16(<2 x bfloat> %a)
326+
327+
Overview:
328+
"""""""""
329+
330+
The '``llvm.nvvm.fabs.*``' intrinsics return the absolute value of the operand.
331+
332+
Semantics:
333+
""""""""""
334+
335+
Unlike '``llvm.fabs.*``', these intrinsics do not perfectly preserve NaN
336+
values. Instead, a NaN input yields an unspecified NaN output.
337+
338+
339+
'``llvm.nvvm.fabs.ftz.*``' Intrinsic
340+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
341+
342+
Syntax:
343+
"""""""
344+
345+
.. code-block:: llvm
346+
347+
declare float @llvm.nvvm.fabs.ftz.f32(float %a)
348+
declare half @llvm.nvvm.fabs.ftz.f16(half %a)
349+
declare <2 x half> @llvm.nvvm.fabs.ftz.v2f16(<2 x half> %a)
350+
351+
Overview:
352+
"""""""""
353+
354+
The '``llvm.nvvm.fabs.ftz.*``' intrinsics return the absolute value of the
355+
operand, flushing subnormals to sign-preserving zero.
356+
357+
Semantics:
358+
""""""""""
359+
360+
Before the absolute value is taken, the input is flushed to sign-preserving
361+
zero if it is a subnormal. In addition, unlike '``llvm.fabs.*``', a NaN input
362+
yields an unspecified NaN output.
363+
364+
312365
'``llvm.nvvm.idp2a.[us].[us]``' Intrinsics
313366
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
314367

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,18 +1039,18 @@ let TargetPrefix = "nvvm" in {
10391039
// Abs
10401040
//
10411041

1042-
def int_nvvm_fabs_ftz_f : ClangBuiltin<"__nvvm_fabs_ftz_f">,
1043-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
1044-
def int_nvvm_fabs_f : ClangBuiltin<"__nvvm_fabs_f">,
1045-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
1046-
def int_nvvm_fabs_d : ClangBuiltin<"__nvvm_fabs_d">,
1047-
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem, IntrSpeculatable]>;
1042+
def int_nvvm_fabs_ftz :
1043+
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
1044+
[IntrNoMem, IntrSpeculatable]>;
10481045

1046+
def int_nvvm_fabs :
1047+
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
1048+
[IntrNoMem, IntrSpeculatable]>;
10491049
//
10501050
// Abs, Neg bf16, bf16x2
10511051
//
10521052

1053-
foreach unary = ["abs", "neg"] in {
1053+
foreach unary = ["neg"] in {
10541054
def int_nvvm_ # unary # _bf16 :
10551055
ClangBuiltin<!strconcat("__nvvm_", unary, "_bf16")>,
10561056
DefaultAttrsIntrinsic<[llvm_bfloat_ty], [llvm_bfloat_ty], [IntrNoMem]>;

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -939,12 +939,6 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
939939
}
940940

941941
static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
942-
if (Name.consume_front("abs."))
943-
return StringSwitch<Intrinsic::ID>(Name)
944-
.Case("bf16", Intrinsic::nvvm_abs_bf16)
945-
.Case("bf16x2", Intrinsic::nvvm_abs_bf16x2)
946-
.Default(Intrinsic::not_intrinsic);
947-
948942
if (Name.consume_front("fma.rn."))
949943
return StringSwitch<Intrinsic::ID>(Name)
950944
.Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
@@ -1291,7 +1285,8 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
12911285
bool Expand = false;
12921286
if (Name.consume_front("abs."))
12931287
// nvvm.abs.{i,ll}
1294-
Expand = Name == "i" || Name == "ll";
1288+
Expand =
1289+
Name == "i" || Name == "ll" || Name == "bf16" || Name == "bf16x2";
12951290
else if (Name == "clz.ll" || Name == "popc.ll" || Name == "h2f" ||
12961291
Name == "swap.lo.hi.b64")
12971292
Expand = true;
@@ -2316,6 +2311,13 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
23162311
Value *Cmp = Builder.CreateICmpSGE(
23172312
Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
23182313
Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
2314+
} else if (Name == "abs.bf16" || Name == "abs.bf16x2") {
2315+
Type *Ty = (Name == "abs.bf16")
2316+
? Builder.getBFloatTy()
2317+
: FixedVectorType::get(Builder.getBFloatTy(), 2);
2318+
Value *Arg = Builder.CreateBitCast(CI->getArgOperand(0), Ty);
2319+
Value *Abs = Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_fabs, Arg);
2320+
Rep = Builder.CreateBitCast(Abs, CI->getType());
23192321
} else if (Name.starts_with("atomic.load.add.f32.p") ||
23202322
Name.starts_with("atomic.load.add.f64.p")) {
23212323
Value *Ptr = CI->getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,17 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
226226
int Size = ty.Size;
227227
}
228228

229-
def I16RT : RegTyInfo<i16, Int16Regs, i16imm, imm>;
230-
def I32RT : RegTyInfo<i32, Int32Regs, i32imm, imm>;
231-
def I64RT : RegTyInfo<i64, Int64Regs, i64imm, imm>;
232-
233-
def F32RT : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
234-
def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
235-
def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
236-
def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
229+
def I16RT : RegTyInfo<i16, Int16Regs, i16imm, imm>;
230+
def I32RT : RegTyInfo<i32, Int32Regs, i32imm, imm>;
231+
def I64RT : RegTyInfo<i64, Int64Regs, i64imm, imm>;
232+
233+
def F32RT : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
234+
def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
235+
def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
236+
def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
237+
238+
def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
239+
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
237240

238241

239242
// Template for instructions which take three int64, int32, or int16 args.

0 commit comments

Comments
 (0)