Skip to content

Commit 785e7d0

Browse files
committed
[NVPTX] Cleanup and document nvvm.fabs intrinsics, adding f16 support
1 parent f75dce4 commit 785e7d0

File tree

8 files changed

+312
-103
lines changed

8 files changed

+312
-103
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,58 @@ space casted to this space), 1 is returned, otherwise 0 is returned.
309309
Arithmetic Intrinsics
310310
---------------------
311311

312+
'``llvm.nvvm.fabs.*``' Intrinsic
313+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
314+
315+
Syntax:
316+
"""""""
317+
318+
.. code-block:: llvm
319+
320+
declare float @llvm.nvvm.fabs.f32(float %a)
321+
declare double @llvm.nvvm.fabs.f64(double %a)
322+
declare half @llvm.nvvm.fabs.f16(half %a)
323+
declare <2 x bfloat> @llvm.nvvm.fabs.v2bf16(<2 x bfloat> %a)
324+
325+
Overview:
326+
"""""""""
327+
328+
The '``llvm.nvvm.fabs.*``' intrinsics return the absolute value of the operand.
329+
330+
Semantics:
331+
""""""""""
332+
333+
Unlike '``llvm.fabs.*``', these intrinsics do not perfectly preserve NaN
334+
values. Instead, a NaN input yields an unspecified NaN output. The exception to
335+
this rule is the double precision variant, for which NaN is preserved.
336+
337+
338+
'``llvm.nvvm.fabs.ftz.*``' Intrinsic
339+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
340+
341+
Syntax:
342+
"""""""
343+
344+
.. code-block:: llvm
345+
346+
declare float @llvm.nvvm.fabs.ftz.f32(float %a)
347+
declare half @llvm.nvvm.fabs.ftz.f16(half %a)
348+
declare <2 x half> @llvm.nvvm.fabs.ftz.v2f16(<2 x half> %a)
349+
350+
Overview:
351+
"""""""""
352+
353+
The '``llvm.nvvm.fabs.ftz.*``' intrinsics return the absolute value of the
354+
operand, flushing subnormals to sign-preserving zero.
355+
356+
Semantics:
357+
""""""""""
358+
359+
Before the absolute value is taken, the input is flushed to sign-preserving
360+
zero if it is a subnormal. In addition, unlike '``llvm.fabs.*``', a NaN input
361+
yields an unspecified NaN output.
362+
363+
312364
'``llvm.nvvm.idp2a.[us].[us]``' Intrinsics
313365
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
314366

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,18 +1039,18 @@ let TargetPrefix = "nvvm" in {
10391039
// Abs
10401040
//
10411041

1042-
def int_nvvm_fabs_ftz_f : ClangBuiltin<"__nvvm_fabs_ftz_f">,
1043-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
1044-
def int_nvvm_fabs_f : ClangBuiltin<"__nvvm_fabs_f">,
1045-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
1046-
def int_nvvm_fabs_d : ClangBuiltin<"__nvvm_fabs_d">,
1047-
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem, IntrSpeculatable]>;
1042+
def int_nvvm_fabs_ftz :
1043+
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
1044+
[IntrNoMem, IntrSpeculatable]>;
10481045

1046+
def int_nvvm_fabs :
1047+
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
1048+
[IntrNoMem, IntrSpeculatable]>;
10491049
//
10501050
// Abs, Neg bf16, bf16x2
10511051
//
10521052

1053-
foreach unary = ["abs", "neg"] in {
1053+
foreach unary = ["neg"] in {
10541054
def int_nvvm_ # unary # _bf16 :
10551055
ClangBuiltin<!strconcat("__nvvm_", unary, "_bf16")>,
10561056
DefaultAttrsIntrinsic<[llvm_bfloat_ty], [llvm_bfloat_ty], [IntrNoMem]>;

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -939,12 +939,6 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
939939
}
940940

941941
static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
942-
if (Name.consume_front("abs."))
943-
return StringSwitch<Intrinsic::ID>(Name)
944-
.Case("bf16", Intrinsic::nvvm_abs_bf16)
945-
.Case("bf16x2", Intrinsic::nvvm_abs_bf16x2)
946-
.Default(Intrinsic::not_intrinsic);
947-
948942
if (Name.consume_front("fma.rn."))
949943
return StringSwitch<Intrinsic::ID>(Name)
950944
.Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
@@ -1291,7 +1285,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
12911285
bool Expand = false;
12921286
if (Name.consume_front("abs."))
12931287
// nvvm.abs.{i,ll,bf16,bf16x2}
1294-
Expand = Name == "i" || Name == "ll";
1288+
Expand = Name == "i" || Name == "ll" || Name == "bf16" || Name == "bf16x2";
12951289
else if (Name == "clz.ll" || Name == "popc.ll" || Name == "h2f" ||
12961290
Name == "swap.lo.hi.b64")
12971291
Expand = true;
@@ -2311,6 +2305,11 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
23112305
Value *Cmp = Builder.CreateICmpSGE(
23122306
Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
23132307
Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
2308+
} else if (Name == "abs.bf16" || Name == "abs.bf16x2") {
2309+
Type *Ty = (Name == "abs.bf16") ? Builder.getBFloatTy() : FixedVectorType::get(Builder.getBFloatTy(), 2);
2310+
Value *Arg = Builder.CreateBitCast(CI->getArgOperand(0), Ty);
2311+
Value *Abs = Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_fabs, Arg);
2312+
Rep = Builder.CreateBitCast(Abs, CI->getType());
23142313
} else if (Name.starts_with("atomic.load.add.f32.p") ||
23152314
Name.starts_with("atomic.load.add.f64.p")) {
23162315
Value *Ptr = CI->getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,17 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
226226
int Size = ty.Size;
227227
}
228228

229-
def I16RT : RegTyInfo<i16, Int16Regs, i16imm, imm>;
230-
def I32RT : RegTyInfo<i32, Int32Regs, i32imm, imm>;
231-
def I64RT : RegTyInfo<i64, Int64Regs, i64imm, imm>;
232-
233-
def F32RT : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
234-
def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
235-
def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
236-
def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
229+
def I16RT : RegTyInfo<i16, Int16Regs, i16imm, imm>;
230+
def I32RT : RegTyInfo<i32, Int32Regs, i32imm, imm>;
231+
def I64RT : RegTyInfo<i64, Int64Regs, i64imm, imm>;
232+
233+
def F32RT : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
234+
def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
235+
def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
236+
def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
237+
238+
def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
239+
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
237240

238241

239242
// Template for instructions which take three int64, int32, or int16 args.

0 commit comments

Comments
 (0)