[NVPTX] Cleanup and document nvvm.fabs intrinsics, adding f16 support #135644
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-ir Author: Alex MacLean (AlexMaclean)

Changes: This change unifies the NVVM intrinsics for floating-point absolute value into two new overloaded intrinsics, "llvm.nvvm.fabs.*" and "llvm.nvvm.fabs.ftz.*".

Patch is 25.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135644.diff 8 Files Affected:
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 621879fc5648b..fbb7122b9b42d 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -309,6 +309,58 @@ space casted to this space), 1 is returned, otherwise 0 is returned.
Arithmetic Intrinsics
---------------------
+'``llvm.nvvm.fabs.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare float @llvm.nvvm.fabs.f32(float %a)
+ declare double @llvm.nvvm.fabs.f64(double %a)
+ declare half @llvm.nvvm.fabs.f16(half %a)
+ declare <2 x bfloat> @llvm.nvvm.fabs.v2bf16(<2 x bfloat> %a)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.fabs.*``' intrinsics return the absolute value of the operand.
+
+Semantics:
+""""""""""
+
+Unlike '``llvm.fabs.*``', these intrinsics do not perfectly preserve NaN
+values. Instead, a NaN input yields an unspecified NaN output. The exception to
+this rule is the double-precision variant, for which NaN is preserved.
+
+
+'``llvm.nvvm.fabs.ftz.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare float @llvm.nvvm.fabs.ftz.f32(float %a)
+ declare half @llvm.nvvm.fabs.ftz.f16(half %a)
+ declare <2 x half> @llvm.nvvm.fabs.ftz.v2f16(<2 x half> %a)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.fabs.ftz.*``' intrinsics return the absolute value of the
+operand, flushing subnormals to sign-preserving zero.
+
+Semantics:
+""""""""""
+
+Before the absolute value is taken, the input is flushed to sign-preserving
+zero if it is a subnormal. In addition, unlike '``llvm.fabs.*``', a NaN input
+yields an unspecified NaN output.
+
+
'``llvm.nvvm.idp2a.[us].[us]``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
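The ``fabs.ftz`` semantics documented in the RST changes above can be modeled at the value level — a minimal sketch of the documented behavior for non-NaN inputs, not how the hardware computes it: a subnormal input is first flushed to a sign-preserving zero, then the absolute value is taken, so every subnormal maps to +0.0.

```python
import math

F32_MIN_NORMAL = 2.0 ** -126  # smallest normal single-precision magnitude

def fabs_ftz_f32(x: float) -> float:
    """Model of llvm.nvvm.fabs.ftz.f32 on non-NaN inputs."""
    # Flush a subnormal input to sign-preserving zero first...
    if x != 0.0 and not math.isnan(x) and abs(x) < F32_MIN_NORMAL:
        x = math.copysign(0.0, x)
    # ...then take the absolute value; abs(+/-0.0) is +0.0.
    return abs(x)
```

For NaN inputs the documented result is an unspecified NaN, so the model above deliberately makes no claim there.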
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 4aeb1d8a2779e..2cea00c640a02 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1039,18 +1039,18 @@ let TargetPrefix = "nvvm" in {
// Abs
//
- def int_nvvm_fabs_ftz_f : ClangBuiltin<"__nvvm_fabs_ftz_f">,
- DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
- def int_nvvm_fabs_f : ClangBuiltin<"__nvvm_fabs_f">,
- DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
- def int_nvvm_fabs_d : ClangBuiltin<"__nvvm_fabs_d">,
- DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem, IntrSpeculatable]>;
+ def int_nvvm_fabs_ftz :
+ DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
+ [IntrNoMem, IntrSpeculatable]>;
+ def int_nvvm_fabs :
+ DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
+ [IntrNoMem, IntrSpeculatable]>;
//
// Abs, Neg bf16, bf16x2
//
- foreach unary = ["abs", "neg"] in {
+ foreach unary = ["neg"] in {
def int_nvvm_ # unary # _bf16 :
ClangBuiltin<!strconcat("__nvvm_", unary, "_bf16")>,
DefaultAttrsIntrinsic<[llvm_bfloat_ty], [llvm_bfloat_ty], [IntrNoMem]>;
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 0b329d91c3c7c..79709ae1b1a1d 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -939,12 +939,6 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
}
static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
- if (Name.consume_front("abs."))
- return StringSwitch<Intrinsic::ID>(Name)
- .Case("bf16", Intrinsic::nvvm_abs_bf16)
- .Case("bf16x2", Intrinsic::nvvm_abs_bf16x2)
- .Default(Intrinsic::not_intrinsic);
-
if (Name.consume_front("fma.rn."))
return StringSwitch<Intrinsic::ID>(Name)
.Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
@@ -1291,7 +1285,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
bool Expand = false;
if (Name.consume_front("abs."))
// nvvm.abs.{i,ll}
- Expand = Name == "i" || Name == "ll";
+ Expand = Name == "i" || Name == "ll" || Name == "bf16" || Name == "bf16x2";
else if (Name == "clz.ll" || Name == "popc.ll" || Name == "h2f" ||
Name == "swap.lo.hi.b64")
Expand = true;
@@ -2311,6 +2305,11 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
Value *Cmp = Builder.CreateICmpSGE(
Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
+ } else if (Name == "abs.bf16" || Name == "abs.bf16x2") {
+ Type *Ty = (Name == "abs.bf16") ? Builder.getBFloatTy() : FixedVectorType::get(Builder.getBFloatTy(), 2);
+ Value *Arg = Builder.CreateBitCast(CI->getArgOperand(0), Ty);
+ Value *Abs = Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_fabs, Arg);
+ Rep = Builder.CreateBitCast(Abs, CI->getType());
} else if (Name.starts_with("atomic.load.add.f32.p") ||
Name.starts_with("atomic.load.add.f64.p")) {
Value *Ptr = CI->getArgOperand(0);
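The two AutoUpgrade expansions in the hunk above can be sketched at the value level — a hedged model of the emitted IR's behavior, not the IRBuilder code itself. Integer ``nvvm.abs`` becomes a compare-and-select; the legacy packed ``abs.bf16x2`` becomes a bitcast to ``<2 x bfloat>`` plus ``nvvm.fabs``, which for non-NaN lanes amounts to clearing each lane's sign bit.

```python
def upgrade_abs_i32(x: int) -> int:
    # select(icmp sge x, 0), x, sub(0, x) on a 32-bit two's-complement value.
    x &= 0xFFFFFFFF
    signed = x - (1 << 32) if x & 0x80000000 else x
    return x if signed >= 0 else (-x) & 0xFFFFFFFF

def upgrade_abs_bf16x2(packed: int) -> int:
    # Bitcast i32 -> <2 x bfloat>, apply nvvm.fabs, bitcast back: for non-NaN
    # lanes this clears bit 15 (the sign bit) of each 16-bit bfloat lane.
    return packed & 0x7FFF_7FFF
```

Note that ``upgrade_abs_i32(0x80000000)`` stays ``0x80000000``, exactly as the select-based IR wraps on INT_MIN.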
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index aa0eedb1b7446..62fdb6d2523f4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -226,14 +226,17 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
int Size = ty.Size;
}
-def I16RT : RegTyInfo<i16, Int16Regs, i16imm, imm>;
-def I32RT : RegTyInfo<i32, Int32Regs, i32imm, imm>;
-def I64RT : RegTyInfo<i64, Int64Regs, i64imm, imm>;
-
-def F32RT : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
-def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
-def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
-def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
+def I16RT : RegTyInfo<i16, Int16Regs, i16imm, imm>;
+def I32RT : RegTyInfo<i32, Int32Regs, i32imm, imm>;
+def I64RT : RegTyInfo<i64, Int64Regs, i64imm, imm>;
+
+def F32RT : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
+def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
+def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
+def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
+
+def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
+def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
// Template for instructions which take three int64, int32, or int16 args.
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8528ff702f236..6f6601555f6e4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -983,12 +983,13 @@ def : Pat<(int_nvvm_fmin_d
// We need a full string for OpcStr here because we need to deal with case like
// INT_PTX_RECIP.
-class F_MATH_1<string OpcStr, NVPTXRegClass target_regclass,
- NVPTXRegClass src_regclass, Intrinsic IntOP, list<Predicate> Preds = []>
- : NVPTXInst<(outs target_regclass:$dst), (ins src_regclass:$src0),
- OpcStr,
- [(set target_regclass:$dst, (IntOP src_regclass:$src0))]>,
- Requires<Preds>;
+class F_MATH_1<string OpcStr, RegTyInfo dst, RegTyInfo src, Intrinsic IntOP,
+ list<Predicate> Preds = []>
+ : NVPTXInst<(outs dst.RC:$dst),
+ (ins src.RC:$src0),
+ OpcStr,
+ [(set dst.Ty:$dst, (IntOP src.Ty:$src0))]>,
+ Requires<Preds>;
// We need a full string for OpcStr here because we need to deal with the case
// like INT_PTX_NATIVE_POWR_F.
@@ -1307,13 +1308,20 @@ def : Pat<(int_nvvm_ceil_d f64:$a),
// Abs
//
-def INT_NVVM_FABS_FTZ_F : F_MATH_1<"abs.ftz.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_fabs_ftz_f>;
-def INT_NVVM_FABS_F : F_MATH_1<"abs.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_fabs_f>;
+multiclass F_ABS<string suffix, RegTyInfo RT, bit support_ftz, list<Predicate> preds = []> {
+ def "" : F_MATH_1<"abs." # suffix # " \t$dst, $src0;", RT, RT, int_nvvm_fabs, preds>;
+ if support_ftz then
+ def _FTZ : F_MATH_1<"abs.ftz." # suffix # " \t$dst, $src0;", RT, RT, int_nvvm_fabs_ftz, preds>;
+}
+
+defm ABS_F16 : F_ABS<"f16", F16RT, support_ftz = true, preds = [hasPTX<65>, hasSM<53>]>;
+defm ABS_F16X2 : F_ABS<"f16x2", F16X2RT, support_ftz = true, preds = [hasPTX<65>, hasSM<53>]>;
+
+defm ABS_BF16 : F_ABS<"bf16", BF16RT, support_ftz = false, preds = [hasPTX<70>, hasSM<80>]>;
+defm ABS_BF16X2 : F_ABS<"bf16x2", BF16X2RT, support_ftz = false, preds = [hasPTX<70>, hasSM<80>]>;
-def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_fabs_d>;
+defm ABS_F32 : F_ABS<"f32", F32RT, support_ftz = true>;
+defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>;
//
// copysign
@@ -1332,17 +1340,13 @@ def COPYSIGN_D :
[(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>;
//
-// Abs, Neg bf16, bf16x2
+// Neg bf16, bf16x2
//
-def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $src0;", Int16Regs,
- Int16Regs, int_nvvm_abs_bf16, [hasPTX<70>, hasSM<80>]>;
-def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $src0;", Int32Regs,
- Int32Regs, int_nvvm_abs_bf16x2, [hasPTX<70>, hasSM<80>]>;
-def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", Int16Regs,
- Int16Regs, int_nvvm_neg_bf16, [hasPTX<70>, hasSM<80>]>;
-def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs,
- Int32Regs, int_nvvm_neg_bf16x2, [hasPTX<70>, hasSM<80>]>;
+def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", BF16RT,
+ BF16RT, int_nvvm_neg_bf16, [hasPTX<70>, hasSM<80>]>;
+def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", BF16X2RT,
+ BF16X2RT, int_nvvm_neg_bf16x2, [hasPTX<70>, hasSM<80>]>;
//
// Round
@@ -1382,16 +1386,16 @@ def : Pat<(int_nvvm_saturate_d f64:$a),
//
def INT_NVVM_EX2_APPROX_FTZ_F : F_MATH_1<"ex2.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_ex2_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_ex2_approx_ftz_f>;
def INT_NVVM_EX2_APPROX_F : F_MATH_1<"ex2.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_ex2_approx_f>;
+ F32RT, F32RT, int_nvvm_ex2_approx_f>;
def INT_NVVM_EX2_APPROX_D : F_MATH_1<"ex2.approx.f64 \t$dst, $src0;",
- Float64Regs, Float64Regs, int_nvvm_ex2_approx_d>;
+ F64RT, F64RT, int_nvvm_ex2_approx_d>;
def INT_NVVM_EX2_APPROX_F16 : F_MATH_1<"ex2.approx.f16 \t$dst, $src0;",
- Int16Regs, Int16Regs, int_nvvm_ex2_approx_f16, [hasPTX<70>, hasSM<75>]>;
+ F16RT, F16RT, int_nvvm_ex2_approx_f16, [hasPTX<70>, hasSM<75>]>;
def INT_NVVM_EX2_APPROX_F16X2 : F_MATH_1<"ex2.approx.f16x2 \t$dst, $src0;",
- Int32Regs, Int32Regs, int_nvvm_ex2_approx_f16x2, [hasPTX<70>, hasSM<75>]>;
+ F16X2RT, F16X2RT, int_nvvm_ex2_approx_f16x2, [hasPTX<70>, hasSM<75>]>;
def : Pat<(fexp2 f32:$a),
(INT_NVVM_EX2_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>;
@@ -1403,11 +1407,11 @@ def : Pat<(fexp2 v2f16:$a),
(INT_NVVM_EX2_APPROX_F16X2 $a)>, Requires<[useFP16Math]>;
def INT_NVVM_LG2_APPROX_FTZ_F : F_MATH_1<"lg2.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_lg2_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_lg2_approx_ftz_f>;
def INT_NVVM_LG2_APPROX_F : F_MATH_1<"lg2.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_lg2_approx_f>;
+ F32RT, F32RT, int_nvvm_lg2_approx_f>;
def INT_NVVM_LG2_APPROX_D : F_MATH_1<"lg2.approx.f64 \t$dst, $src0;",
- Float64Regs, Float64Regs, int_nvvm_lg2_approx_d>;
+ F64RT, F64RT, int_nvvm_lg2_approx_d>;
def : Pat<(flog2 f32:$a), (INT_NVVM_LG2_APPROX_FTZ_F $a)>,
Requires<[doF32FTZ]>;
@@ -1419,14 +1423,14 @@ def : Pat<(flog2 f32:$a), (INT_NVVM_LG2_APPROX_F $a)>,
//
def INT_NVVM_SIN_APPROX_FTZ_F : F_MATH_1<"sin.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sin_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_sin_approx_ftz_f>;
def INT_NVVM_SIN_APPROX_F : F_MATH_1<"sin.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sin_approx_f>;
+ F32RT, F32RT, int_nvvm_sin_approx_f>;
def INT_NVVM_COS_APPROX_FTZ_F : F_MATH_1<"cos.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_cos_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_cos_approx_ftz_f>;
def INT_NVVM_COS_APPROX_F : F_MATH_1<"cos.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_cos_approx_f>;
+ F32RT, F32RT, int_nvvm_cos_approx_f>;
//
// Fma
@@ -1511,69 +1515,69 @@ defm INT_NVVM_FMA : FMA_INST;
//
def INT_NVVM_RCP_RN_FTZ_F : F_MATH_1<"rcp.rn.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rn_ftz_f>;
+ F32RT, F32RT, int_nvvm_rcp_rn_ftz_f>;
def INT_NVVM_RCP_RN_F : F_MATH_1<"rcp.rn.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rn_f>;
+ F32RT, F32RT, int_nvvm_rcp_rn_f>;
def INT_NVVM_RCP_RZ_FTZ_F : F_MATH_1<"rcp.rz.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rz_ftz_f>;
+ F32RT, F32RT, int_nvvm_rcp_rz_ftz_f>;
def INT_NVVM_RCP_RZ_F : F_MATH_1<"rcp.rz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rz_f>;
+ F32RT, F32RT, int_nvvm_rcp_rz_f>;
def INT_NVVM_RCP_RM_FTZ_F : F_MATH_1<"rcp.rm.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rm_ftz_f>;
+ F32RT, F32RT, int_nvvm_rcp_rm_ftz_f>;
def INT_NVVM_RCP_RM_F : F_MATH_1<"rcp.rm.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rm_f>;
+ F32RT, F32RT, int_nvvm_rcp_rm_f>;
def INT_NVVM_RCP_RP_FTZ_F : F_MATH_1<"rcp.rp.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rp_ftz_f>;
+ F32RT, F32RT, int_nvvm_rcp_rp_ftz_f>;
def INT_NVVM_RCP_RP_F : F_MATH_1<"rcp.rp.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_rp_f>;
+ F32RT, F32RT, int_nvvm_rcp_rp_f>;
-def INT_NVVM_RCP_RN_D : F_MATH_1<"rcp.rn.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rn_d>;
-def INT_NVVM_RCP_RZ_D : F_MATH_1<"rcp.rz.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rz_d>;
-def INT_NVVM_RCP_RM_D : F_MATH_1<"rcp.rm.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rm_d>;
-def INT_NVVM_RCP_RP_D : F_MATH_1<"rcp.rp.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rp_d>;
+def INT_NVVM_RCP_RN_D : F_MATH_1<"rcp.rn.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rn_d>;
+def INT_NVVM_RCP_RZ_D : F_MATH_1<"rcp.rz.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rz_d>;
+def INT_NVVM_RCP_RM_D : F_MATH_1<"rcp.rm.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rm_d>;
+def INT_NVVM_RCP_RP_D : F_MATH_1<"rcp.rp.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rp_d>;
def INT_NVVM_RCP_APPROX_FTZ_F : F_MATH_1<"rcp.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_rcp_approx_ftz_f>;
def INT_NVVM_RCP_APPROX_FTZ_D : F_MATH_1<"rcp.approx.ftz.f64 \t$dst, $src0;",
- Float64Regs, Float64Regs, int_nvvm_rcp_approx_ftz_d>;
+ F64RT, F64RT, int_nvvm_rcp_approx_ftz_d>;
//
// Sqrt
//
def INT_NVVM_SQRT_RN_FTZ_F : F_MATH_1<"sqrt.rn.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rn_ftz_f>;
-def INT_NVVM_SQRT_RN_F : F_MATH_1<"sqrt.rn.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rn_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rn_ftz_f>;
+def INT_NVVM_SQRT_RN_F : F_MATH_1<"sqrt.rn.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rn_f>;
def INT_NVVM_SQRT_RZ_FTZ_F : F_MATH_1<"sqrt.rz.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rz_ftz_f>;
-def INT_NVVM_SQRT_RZ_F : F_MATH_1<"sqrt.rz.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rz_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rz_ftz_f>;
+def INT_NVVM_SQRT_RZ_F : F_MATH_1<"sqrt.rz.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rz_f>;
def INT_NVVM_SQRT_RM_FTZ_F : F_MATH_1<"sqrt.rm.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rm_ftz_f>;
-def INT_NVVM_SQRT_RM_F : F_MATH_1<"sqrt.rm.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rm_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rm_ftz_f>;
+def INT_NVVM_SQRT_RM_F : F_MATH_1<"sqrt.rm.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rm_f>;
def INT_NVVM_SQRT_RP_FTZ_F : F_MATH_1<"sqrt.rp.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rp_ftz_f>;
-def INT_NVVM_SQRT_RP_F : F_MATH_1<"sqrt.rp.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rp_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rp_ftz_f>;
+def INT_NVVM_SQRT_RP_F : F_MATH_1<"sqrt.rp.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rp_f>;
def INT_NVVM_SQRT_APPROX_FTZ_F : F_MATH_1<"sqrt.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_sqrt_approx_ftz_f>;
def INT_NVVM_SQRT_APPROX_F : F_MATH_1<"sqrt.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_approx_f>;
+ F32RT, F32RT, int_nvvm_sqrt_approx_f>;
-def INT_NVVM_SQRT_RN_D : F_MATH_1<"sqrt.rn.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rn_d>;
-def INT_NVVM_SQRT_RZ_D : F_MATH_1<"sqrt.rz.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rz_d>;
-def INT_NVVM_SQRT_RM_D : F_MATH_1<"sqrt.rm.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rm_d>;
-def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rp_d>;
+def INT_NVVM_SQRT_RN_D : F_MATH_1<"sqrt.rn.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rn_d>;
+def INT_NVVM_SQRT_RZ_D : F_MATH_1<"sqrt.rz.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rz_d>;
+def INT_NVVM_SQRT_RM_D : F_MATH_1<"sqrt.rm.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rm_d>;
+def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rp_d>;
// nvvm_sqrt intrinsic
def : Pat<(int_nvvm_sqrt_f f32:$a),
@@ -1590,16 +1594,16 @@ def : Pat<(int_nvvm_sqrt_f f32:$a),
//
def INT_NVVM_RSQRT_APPROX_FTZ_F
- : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
+ : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", F32RT, F32RT,
int_nvvm_rsqrt_approx_ftz_f>;
def INT_NVVM_RSQRT_APPROX_FTZ_D
- : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+ : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", F64RT, F64RT,
int_nvvm_rsqrt_approx_ftz_d>;
def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
+ F32RT, F32RT, int_nvvm_rsqrt_approx_f>;
def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
- Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
+ F64RT, F64RT, int_nvvm_rsqrt_approx_d>;
// 1.0f / sqrt_approx -> rsqrt_approx
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f f32:$a)),
@@ -1815,13 +1819,13 @@ def INT_NVVM_D2I_LO : F_MATH_1<
".reg .b32 %temp; \n\t",
"mov.b64 \t{$dst, %temp}, $src0;\n\t",
"}}"),
- Int32Regs, Float64Regs, int_nvvm_d2i_lo>;
+ I32RT, F64RT, int_nvvm_d2i_lo>;
def INT_NVVM_D2I_HI : F_MATH_1<
!strconcat("{{\n\t",
".reg .b32 %temp; \n\t",
"mov.b64 \t{%temp, $dst}, $src0;\n\t",
"}}"),
- Int32Regs, Float64Regs, int_nvvm_d2i_hi>;
+ I32RT, F64RT, int_nvvm_d2i_hi>;
def : Pat<(int_nvvm_f2ll_rn_ftz f32:$a),
(CVT_s64_f32 $a, CvtRNI_FTZ)>;
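The ``F_ABS`` multiclass in the TableGen hunk above can be read as a small generator — an illustrative sketch of what TableGen expands for each ``defm`` line, with the record-name/asm-string pairing mirroring the multiclass body (the empty-named def plus an optional ``_FTZ`` twin):

```python
def expand_f_abs(defm_name: str, suffix: str, support_ftz: bool) -> dict[str, str]:
    """Model the records F_ABS produces: record name -> PTX asm string."""
    # def "" : the plain abs.<suffix> instruction, named after the defm itself.
    insts = {defm_name: f"abs.{suffix} \t$dst, $src0;"}
    # def _FTZ : only emitted when the type supports flush-to-zero.
    if support_ftz:
        insts[defm_name + "_FTZ"] = f"abs.ftz.{suffix} \t$dst, $src0;"
    return insts
```

So ``ABS_F32`` expands to both ``abs.f32`` and ``abs.ftz.f32`` records, while ``ABS_F64`` and the bf16 variants get only the plain form, matching the ``support_ftz`` flags on the ``defm`` lines.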
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 81ad01bea8867..15695c287d52d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -141,6 +141,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
enum SpecialCase {
SPC_Reciprocal,
SCP_FunnelShiftClamp,
+ SPC_Fabs,
};
// SimplifyAction is a poor-man's variant (plus an additional flag) that
@@ -185,8 +186,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
return {Intrinsic::ceil, FTZ_MustBeOff};
case Intrinsic::nvvm_ceil_ftz_f:
return {Intrinsic::ceil, FTZ_MustBeOn};
- case Intrinsic::nvvm_fabs_d:
- return {Intrinsic::fabs, FTZ_Any};
+ case Intrinsic::nvvm_fabs:
+ return {SPC_Fabs, FTZ_Any};
case Intrinsic::nvvm_floor_d:
return {Intrinsic::floor, FTZ_Any};
case Intrinsic::nvvm_floor_f:
@@ -411,6 +412,11 @@ static Instruction *conve...
[truncated]
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean)
+ F32RT, F32RT, int_nvvm_rcp_rp_f>;
-def INT_NVVM_RCP_RN_D : F_MATH_1<"rcp.rn.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rn_d>;
-def INT_NVVM_RCP_RZ_D : F_MATH_1<"rcp.rz.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rz_d>;
-def INT_NVVM_RCP_RM_D : F_MATH_1<"rcp.rm.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rm_d>;
-def INT_NVVM_RCP_RP_D : F_MATH_1<"rcp.rp.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_rcp_rp_d>;
+def INT_NVVM_RCP_RN_D : F_MATH_1<"rcp.rn.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rn_d>;
+def INT_NVVM_RCP_RZ_D : F_MATH_1<"rcp.rz.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rz_d>;
+def INT_NVVM_RCP_RM_D : F_MATH_1<"rcp.rm.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rm_d>;
+def INT_NVVM_RCP_RP_D : F_MATH_1<"rcp.rp.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_rcp_rp_d>;
def INT_NVVM_RCP_APPROX_FTZ_F : F_MATH_1<"rcp.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rcp_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_rcp_approx_ftz_f>;
def INT_NVVM_RCP_APPROX_FTZ_D : F_MATH_1<"rcp.approx.ftz.f64 \t$dst, $src0;",
- Float64Regs, Float64Regs, int_nvvm_rcp_approx_ftz_d>;
+ F64RT, F64RT, int_nvvm_rcp_approx_ftz_d>;
//
// Sqrt
//
def INT_NVVM_SQRT_RN_FTZ_F : F_MATH_1<"sqrt.rn.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rn_ftz_f>;
-def INT_NVVM_SQRT_RN_F : F_MATH_1<"sqrt.rn.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rn_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rn_ftz_f>;
+def INT_NVVM_SQRT_RN_F : F_MATH_1<"sqrt.rn.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rn_f>;
def INT_NVVM_SQRT_RZ_FTZ_F : F_MATH_1<"sqrt.rz.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rz_ftz_f>;
-def INT_NVVM_SQRT_RZ_F : F_MATH_1<"sqrt.rz.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rz_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rz_ftz_f>;
+def INT_NVVM_SQRT_RZ_F : F_MATH_1<"sqrt.rz.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rz_f>;
def INT_NVVM_SQRT_RM_FTZ_F : F_MATH_1<"sqrt.rm.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rm_ftz_f>;
-def INT_NVVM_SQRT_RM_F : F_MATH_1<"sqrt.rm.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rm_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rm_ftz_f>;
+def INT_NVVM_SQRT_RM_F : F_MATH_1<"sqrt.rm.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rm_f>;
def INT_NVVM_SQRT_RP_FTZ_F : F_MATH_1<"sqrt.rp.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_rp_ftz_f>;
-def INT_NVVM_SQRT_RP_F : F_MATH_1<"sqrt.rp.f32 \t$dst, $src0;", Float32Regs,
- Float32Regs, int_nvvm_sqrt_rp_f>;
+ F32RT, F32RT, int_nvvm_sqrt_rp_ftz_f>;
+def INT_NVVM_SQRT_RP_F : F_MATH_1<"sqrt.rp.f32 \t$dst, $src0;", F32RT,
+ F32RT, int_nvvm_sqrt_rp_f>;
def INT_NVVM_SQRT_APPROX_FTZ_F : F_MATH_1<"sqrt.approx.ftz.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_approx_ftz_f>;
+ F32RT, F32RT, int_nvvm_sqrt_approx_ftz_f>;
def INT_NVVM_SQRT_APPROX_F : F_MATH_1<"sqrt.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_sqrt_approx_f>;
+ F32RT, F32RT, int_nvvm_sqrt_approx_f>;
-def INT_NVVM_SQRT_RN_D : F_MATH_1<"sqrt.rn.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rn_d>;
-def INT_NVVM_SQRT_RZ_D : F_MATH_1<"sqrt.rz.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rz_d>;
-def INT_NVVM_SQRT_RM_D : F_MATH_1<"sqrt.rm.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rm_d>;
-def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64 \t$dst, $src0;", Float64Regs,
- Float64Regs, int_nvvm_sqrt_rp_d>;
+def INT_NVVM_SQRT_RN_D : F_MATH_1<"sqrt.rn.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rn_d>;
+def INT_NVVM_SQRT_RZ_D : F_MATH_1<"sqrt.rz.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rz_d>;
+def INT_NVVM_SQRT_RM_D : F_MATH_1<"sqrt.rm.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rm_d>;
+def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64 \t$dst, $src0;", F64RT,
+ F64RT, int_nvvm_sqrt_rp_d>;
// nvvm_sqrt intrinsic
def : Pat<(int_nvvm_sqrt_f f32:$a),
@@ -1590,16 +1594,16 @@ def : Pat<(int_nvvm_sqrt_f f32:$a),
//
def INT_NVVM_RSQRT_APPROX_FTZ_F
- : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
+ : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", F32RT, F32RT,
int_nvvm_rsqrt_approx_ftz_f>;
def INT_NVVM_RSQRT_APPROX_FTZ_D
- : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+ : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", F64RT, F64RT,
int_nvvm_rsqrt_approx_ftz_d>;
def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
- Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
+ F32RT, F32RT, int_nvvm_rsqrt_approx_f>;
def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
- Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
+ F64RT, F64RT, int_nvvm_rsqrt_approx_d>;
// 1.0f / sqrt_approx -> rsqrt_approx
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f f32:$a)),
@@ -1815,13 +1819,13 @@ def INT_NVVM_D2I_LO : F_MATH_1<
".reg .b32 %temp; \n\t",
"mov.b64 \t{$dst, %temp}, $src0;\n\t",
"}}"),
- Int32Regs, Float64Regs, int_nvvm_d2i_lo>;
+ I32RT, F64RT, int_nvvm_d2i_lo>;
def INT_NVVM_D2I_HI : F_MATH_1<
!strconcat("{{\n\t",
".reg .b32 %temp; \n\t",
"mov.b64 \t{%temp, $dst}, $src0;\n\t",
"}}"),
- Int32Regs, Float64Regs, int_nvvm_d2i_hi>;
+ I32RT, F64RT, int_nvvm_d2i_hi>;
def : Pat<(int_nvvm_f2ll_rn_ftz f32:$a),
(CVT_s64_f32 $a, CvtRNI_FTZ)>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 81ad01bea8867..15695c287d52d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -141,6 +141,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
enum SpecialCase {
SPC_Reciprocal,
SCP_FunnelShiftClamp,
+ SPC_Fabs,
};
// SimplifyAction is a poor-man's variant (plus an additional flag) that
@@ -185,8 +186,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
return {Intrinsic::ceil, FTZ_MustBeOff};
case Intrinsic::nvvm_ceil_ftz_f:
return {Intrinsic::ceil, FTZ_MustBeOn};
- case Intrinsic::nvvm_fabs_d:
- return {Intrinsic::fabs, FTZ_Any};
+ case Intrinsic::nvvm_fabs:
+ return {SPC_Fabs, FTZ_Any};
case Intrinsic::nvvm_floor_d:
return {Intrinsic::floor, FTZ_Any};
case Intrinsic::nvvm_floor_f:
@@ -411,6 +412,11 @@ static Instruction *conve...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
785e7d0 to fd11c2b (Compare)
LGTM w/ a doc nit.
def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
Interesting. I never thought of passing `?` as an argument. That can indeed be convenient in some cases.
llvm/docs/NVPTXUsage.rst (Outdated)
Unlike, '``llvm.fabs.*``', these intrinsics do not perfectly preserve NaN
values. Instead, a NaN input yeilds an unspecified NaN output. The exception to
this rule is the double precision variant, for which NaN is preserved.
The "exception" here seems like it's letting too much of the implementation leak through... it's hard to understand the semantics, and it'll make it very hard to transition if we decide to turn this into a target-independent intrinsic. Maybe just say it's always unspecified, and just use the regular llvm.fabs.f64 where the NaN handling matters?
Yea, this is a good point. I've updated the docs to remove the exception. A NaN will now always yield an unspecified NaN output.
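The distinction the reviewer is drawing can be sketched in a few lines of IR (the function names here are illustrative, not from the patch's tests):

```llvm
; Where exact NaN preservation matters, use the target-independent intrinsic:
; llvm.fabs only clears the sign bit, so a NaN input comes out as the same NaN.
define double @abs_nan_preserving(double %x) {
  %r = call double @llvm.fabs.f64(double %x)
  ret double %r
}

; The NVVM intrinsic, per the updated docs, yields an *unspecified* NaN for a
; NaN input, so it is only appropriate when NaN payloads do not matter.
define double @abs_nan_unspecified(double %x) {
  %r = call double @llvm.nvvm.fabs.f64(double %x)
  ret double %r
}

declare double @llvm.fabs.f64(double)
declare double @llvm.nvvm.fabs.f64(double)
```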
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/160/builds/16485. Here is the relevant piece of the build log for reference:
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/180/builds/16474. Here is the relevant piece of the build log for reference:
These auto-upgrade rules are required after these intrinsics were removed in #135644
…llvm#135644) This change unifies the NVVM intrinsics for floating point absolute value into two new overloaded intrinsics "llvm.nvvm.fabs.*" and "llvm.nvvm.fabs.ftz.*". Documentation has been added specifying the semantics of these intrinsics to clarify how they differ from "llvm.fabs.*". In addition, support for these new intrinsics is extended to cover the f16 variants.
These auto-upgrade rules are required after these intrinsics were removed in llvm#135644
This change unifies the NVVM intrinsics for floating point absolute value into two new overloaded intrinsics `llvm.nvvm.fabs.*` and `llvm.nvvm.fabs.ftz.*`. Documentation has been added specifying the semantics of these intrinsics to clarify how they differ from `llvm.fabs.*`. In addition, support for these new intrinsics is extended to cover the `f16` variants.
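A minimal usage sketch of the unified intrinsics, based on the declarations added to NVPTXUsage.rst in this PR (the combination of variants in one function is illustrative):

```llvm
declare float @llvm.nvvm.fabs.f32(float)
declare half @llvm.nvvm.fabs.f16(half)
declare <2 x bfloat> @llvm.nvvm.fabs.v2bf16(<2 x bfloat>)
declare float @llvm.nvvm.fabs.ftz.f32(float)

define float @example(float %a, float %b) {
  ; plain variant: overloaded on the operand type
  %x = call float @llvm.nvvm.fabs.f32(float %a)
  ; ftz variant: flush-to-zero behavior for subnormals (f32 only has an ftz form)
  %y = call float @llvm.nvvm.fabs.ftz.f32(float %b)
  %r = fadd float %x, %y
  ret float %r
}
```

Note that a NaN input to either form yields an unspecified NaN output, per the semantics documented in the PR.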