Reland "[NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handling" #110025

Merged
merged 2 commits into llvm:main from dev/amaclean/reland-rotate
Sep 27, 2024

Conversation

AlexMaclean
Member

@AlexMaclean AlexMaclean commented Sep 25, 2024

This change deprecates the following intrinsics which can be trivially
converted to llvm funnel-shift intrinsics:

  • @llvm.nvvm.rotate.b32
  • @llvm.nvvm.rotate.right.b64
  • @llvm.nvvm.rotate.b64

This fixes a bug in the previous version (#107655) which flipped the order of the operands to the PTX funnel shift instruction. In LLVM IR the high bits are the first arg and the low bits are the second arg, while in PTX this is reversed.
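
In IR terms, the auto-upgrade rewrites the deprecated intrinsics roughly as follows (a sketch of the resulting IR, matching the AutoUpgrade.cpp changes below; %x, %v, and %n are illustrative names, not taken from the patch):

  ; 32-bit rotate left: a rotate is a funnel shift with both value operands equal
  %r32 = call i32 @llvm.fshl.i32(i32 %x, i32 %x, i32 %n)      ; was @llvm.nvvm.rotate.b32(%x, %n)

  ; the 64-bit rotates take an i32 shift amount, which is zero-extended first
  %amt  = zext i32 %n to i64
  %rl64 = call i64 @llvm.fshl.i64(i64 %v, i64 %v, i64 %amt)   ; was @llvm.nvvm.rotate.b64(%v, %n)
  %rr64 = call i64 @llvm.fshr.i64(i64 %v, i64 %v, i64 %amt)   ; was @llvm.nvvm.rotate.right.b64(%v, %n)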

@AlexMaclean AlexMaclean self-assigned this Sep 25, 2024
@AlexMaclean AlexMaclean changed the title Reland "[NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handling (#107655)" Reland "[NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handling" Sep 25, 2024
@llvmbot
Member

llvmbot commented Sep 25, 2024

@llvm/pr-subscribers-llvm-ir

Author: Alex MacLean (AlexMaclean)

Changes

This change deprecates the following intrinsics which can be trivially
converted to llvm funnel-shift intrinsics:

  • @llvm.nvvm.rotate.b32
  • @llvm.nvvm.rotate.right.b64
  • @llvm.nvvm.rotate.b64

This fixes a bug in the previous version which flipped the order of the operands to the PTX funnel shift instruction. In LLVM IR the high bits are the first arg and the low bits are the second arg, while in PTX this is reversed.


Patch is 52.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110025.diff

10 Files Affected:

  • (modified) llvm/docs/ReleaseNotes.rst (+6)
  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (-16)
  • (modified) llvm/lib/IR/AutoUpgrade.cpp (+106-78)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+13-20)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+36-161)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+2-127)
  • (modified) llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll (+17-1)
  • (modified) llvm/test/CodeGen/NVPTX/rotate.ll (+266-167)
  • (modified) llvm/test/CodeGen/NVPTX/rotate_64.ll (+23-10)
diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst
index 05f5bd65fc5f6d..0784d93f18da8f 100644
--- a/llvm/docs/ReleaseNotes.rst
+++ b/llvm/docs/ReleaseNotes.rst
@@ -63,6 +63,12 @@ Changes to the LLVM IR
   * ``llvm.nvvm.bitcast.d2ll``
   * ``llvm.nvvm.bitcast.ll2d``
 
+* Remove the following intrinsics which can be replaced with a funnel-shift:
+
+  * ``llvm.nvvm.rotate.b32``
+  * ``llvm.nvvm.rotate.right.b64``
+  * ``llvm.nvvm.rotate.b64``
+
 Changes to LLVM infrastructure
 ------------------------------
 
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 737dd6092e2183..aa5294f5f9c909 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -4479,22 +4479,6 @@ def int_nvvm_sust_p_3d_v4i32_trap
               "llvm.nvvm.sust.p.3d.v4i32.trap">,
     ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">;
 
-
-def int_nvvm_rotate_b32
-  : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
-              [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">,
-              ClangBuiltin<"__nvvm_rotate_b32">;
-
-def int_nvvm_rotate_b64
-  : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
-             [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">,
-             ClangBuiltin<"__nvvm_rotate_b64">;
-
-def int_nvvm_rotate_right_b64
-  : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
-              [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">,
-              ClangBuiltin<"__nvvm_rotate_right_b64">;
-
 def int_nvvm_swap_lo_hi_b64
   : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty],
               [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">,
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 02d1d9d9f78984..3390d651d6c693 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -1272,6 +1272,9 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
         // nvvm.bitcast.{f2i,i2f,ll2d,d2ll}
         Expand =
             Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll";
+      else if (Name.consume_front("rotate."))
+        // nvvm.rotate.{b32,b64,right.b64}
+        Expand = Name == "b32" || Name == "b64" || Name == "right.b64";
       else
         Expand = false;
 
@@ -2258,6 +2261,108 @@ void llvm::UpgradeInlineAsmString(std::string *AsmStr) {
   }
 }
 
+static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
+                                       Function *F, IRBuilder<> &Builder) {
+  Value *Rep = nullptr;
+
+  if (Name == "abs.i" || Name == "abs.ll") {
+    Value *Arg = CI->getArgOperand(0);
+    Value *Neg = Builder.CreateNeg(Arg, "neg");
+    Value *Cmp = Builder.CreateICmpSGE(
+        Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
+    Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
+  } else if (Name.starts_with("atomic.load.add.f32.p") ||
+             Name.starts_with("atomic.load.add.f64.p")) {
+    Value *Ptr = CI->getArgOperand(0);
+    Value *Val = CI->getArgOperand(1);
+    Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
+                                  AtomicOrdering::SequentiallyConsistent);
+  } else if (Name.consume_front("max.") &&
+             (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
+              Name == "ui" || Name == "ull")) {
+    Value *Arg0 = CI->getArgOperand(0);
+    Value *Arg1 = CI->getArgOperand(1);
+    Value *Cmp = Name.starts_with("u")
+                     ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
+                     : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
+    Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
+  } else if (Name.consume_front("min.") &&
+             (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
+              Name == "ui" || Name == "ull")) {
+    Value *Arg0 = CI->getArgOperand(0);
+    Value *Arg1 = CI->getArgOperand(1);
+    Value *Cmp = Name.starts_with("u")
+                     ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
+                     : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
+    Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
+  } else if (Name == "clz.ll") {
+    // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64.
+    Value *Arg = CI->getArgOperand(0);
+    Value *Ctlz = Builder.CreateCall(
+        Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
+                                  {Arg->getType()}),
+        {Arg, Builder.getFalse()}, "ctlz");
+    Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc");
+  } else if (Name == "popc.ll") {
+    // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an
+    // i64.
+    Value *Arg = CI->getArgOperand(0);
+    Value *Popc = Builder.CreateCall(
+        Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop,
+                                  {Arg->getType()}),
+        Arg, "ctpop");
+    Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
+  } else if (Name == "h2f") {
+    Rep = Builder.CreateCall(
+        Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16,
+                                  {Builder.getFloatTy()}),
+        CI->getArgOperand(0), "h2f");
+  } else if (Name.consume_front("bitcast.") &&
+             (Name == "f2i" || Name == "i2f" || Name == "ll2d" ||
+              Name == "d2ll")) {
+    Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType());
+  } else if (Name == "rotate.b32") {
+    Value *Arg = CI->getOperand(0);
+    Value *ShiftAmt = CI->getOperand(1);
+    Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl,
+                                  {Arg, Arg, ShiftAmt});
+  } else if (Name == "rotate.b64") {
+    Type *Int64Ty = Builder.getInt64Ty();
+    Value *Arg = CI->getOperand(0);
+    Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
+    Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl,
+                                  {Arg, Arg, ZExtShiftAmt});
+  } else if (Name == "rotate.right.b64") {
+    Type *Int64Ty = Builder.getInt64Ty();
+    Value *Arg = CI->getOperand(0);
+    Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
+    Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr,
+                                  {Arg, Arg, ZExtShiftAmt});
+  } else {
+    Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
+    if (IID != Intrinsic::not_intrinsic &&
+        !F->getReturnType()->getScalarType()->isBFloatTy()) {
+      rename(F);
+      Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
+      SmallVector<Value *, 2> Args;
+      for (size_t I = 0; I < NewFn->arg_size(); ++I) {
+        Value *Arg = CI->getArgOperand(I);
+        Type *OldType = Arg->getType();
+        Type *NewType = NewFn->getArg(I)->getType();
+        Args.push_back(
+            (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy())
+                ? Builder.CreateBitCast(Arg, NewType)
+                : Arg);
+      }
+      Rep = Builder.CreateCall(NewFn, Args);
+      if (F->getReturnType()->isIntegerTy())
+        Rep = Builder.CreateBitCast(Rep, F->getReturnType());
+    }
+  }
+
+  return Rep;
+}
+
 static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F,
                                       IRBuilder<> &Builder) {
   LLVMContext &C = F->getContext();
@@ -4208,85 +4313,8 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
 
     if (!IsX86 && Name == "stackprotectorcheck") {
       Rep = nullptr;
-    } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) {
-      Value *Arg = CI->getArgOperand(0);
-      Value *Neg = Builder.CreateNeg(Arg, "neg");
-      Value *Cmp = Builder.CreateICmpSGE(
-          Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
-      Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
-    } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") ||
-                          Name.starts_with("atomic.load.add.f64.p"))) {
-      Value *Ptr = CI->getArgOperand(0);
-      Value *Val = CI->getArgOperand(1);
-      Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
-                                    AtomicOrdering::SequentiallyConsistent);
-    } else if (IsNVVM && Name.consume_front("max.") &&
-               (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
-                Name == "ui" || Name == "ull")) {
-      Value *Arg0 = CI->getArgOperand(0);
-      Value *Arg1 = CI->getArgOperand(1);
-      Value *Cmp = Name.starts_with("u")
-                       ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
-                       : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
-      Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
-    } else if (IsNVVM && Name.consume_front("min.") &&
-               (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
-                Name == "ui" || Name == "ull")) {
-      Value *Arg0 = CI->getArgOperand(0);
-      Value *Arg1 = CI->getArgOperand(1);
-      Value *Cmp = Name.starts_with("u")
-                       ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
-                       : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
-      Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
-    } else if (IsNVVM && Name == "clz.ll") {
-      // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64.
-      Value *Arg = CI->getArgOperand(0);
-      Value *Ctlz = Builder.CreateCall(
-          Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
-                                    {Arg->getType()}),
-          {Arg, Builder.getFalse()}, "ctlz");
-      Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc");
-    } else if (IsNVVM && Name == "popc.ll") {
-      // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an
-      // i64.
-      Value *Arg = CI->getArgOperand(0);
-      Value *Popc = Builder.CreateCall(
-          Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop,
-                                    {Arg->getType()}),
-          Arg, "ctpop");
-      Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
     } else if (IsNVVM) {
-      if (Name == "h2f") {
-        Rep =
-            Builder.CreateCall(Intrinsic::getDeclaration(
-                                   F->getParent(), Intrinsic::convert_from_fp16,
-                                   {Builder.getFloatTy()}),
-                               CI->getArgOperand(0), "h2f");
-      } else if (Name.consume_front("bitcast.") &&
-                 (Name == "f2i" || Name == "i2f" || Name == "ll2d" ||
-                  Name == "d2ll")) {
-        Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType());
-      } else {
-        Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
-        if (IID != Intrinsic::not_intrinsic &&
-            !F->getReturnType()->getScalarType()->isBFloatTy()) {
-          rename(F);
-          NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
-          SmallVector<Value *, 2> Args;
-          for (size_t I = 0; I < NewFn->arg_size(); ++I) {
-            Value *Arg = CI->getArgOperand(I);
-            Type *OldType = Arg->getType();
-            Type *NewType = NewFn->getArg(I)->getType();
-            Args.push_back((OldType->isIntegerTy() &&
-                            NewType->getScalarType()->isBFloatTy())
-                               ? Builder.CreateBitCast(Arg, NewType)
-                               : Arg);
-          }
-          Rep = Builder.CreateCall(NewFn, Args);
-          if (F->getReturnType()->isIntegerTy())
-            Rep = Builder.CreateBitCast(Rep, F->getReturnType());
-        }
-      }
+      Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder);
     } else if (IsX86) {
       Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder);
     } else if (IsARM) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 26888342210918..8718b7890bf58a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -594,20 +594,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
   setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
 
-  // TODO: we may consider expanding ROTL/ROTR on older GPUs.  Currently on GPUs
-  // that don't have h/w rotation we lower them to multi-instruction assembly.
-  // See ROT*_sw in NVPTXIntrInfo.td
-  setOperationAction(ISD::ROTL, MVT::i64, Legal);
-  setOperationAction(ISD::ROTR, MVT::i64, Legal);
-  setOperationAction(ISD::ROTL, MVT::i32, Legal);
-  setOperationAction(ISD::ROTR, MVT::i32, Legal);
-
-  setOperationAction(ISD::ROTL, MVT::i16, Expand);
-  setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
-  setOperationAction(ISD::ROTR, MVT::i16, Expand);
-  setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
-  setOperationAction(ISD::ROTL, MVT::i8, Expand);
-  setOperationAction(ISD::ROTR, MVT::i8, Expand);
+  setOperationAction({ISD::ROTL, ISD::ROTR},
+                     {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
+                     Expand);
+
+  if (STI.hasHWROT32())
+    setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);
+
   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
 
   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
@@ -958,8 +951,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::LDUV4)
     MAKE_CASE(NVPTXISD::StoreV2)
     MAKE_CASE(NVPTXISD::StoreV4)
-    MAKE_CASE(NVPTXISD::FUN_SHFL_CLAMP)
-    MAKE_CASE(NVPTXISD::FUN_SHFR_CLAMP)
+    MAKE_CASE(NVPTXISD::FSHL_CLAMP)
+    MAKE_CASE(NVPTXISD::FSHR_CLAMP)
     MAKE_CASE(NVPTXISD::IMAD)
     MAKE_CASE(NVPTXISD::BFE)
     MAKE_CASE(NVPTXISD::BFI)
@@ -2490,8 +2483,8 @@ SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
     //   dLo = shf.r.clamp aLo, aHi, Amt
 
     SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
-    SDValue Lo = DAG.getNode(NVPTXISD::FUN_SHFR_CLAMP, dl, VT, ShOpLo, ShOpHi,
-                             ShAmt);
+    SDValue Lo =
+        DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
 
     SDValue Ops[2] = { Lo, Hi };
     return DAG.getMergeValues(Ops, dl);
@@ -2549,8 +2542,8 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
     //   dHi = shf.l.clamp aLo, aHi, Amt
     //   dLo = aLo << Amt
 
-    SDValue Hi = DAG.getNode(NVPTXISD::FUN_SHFL_CLAMP, dl, VT, ShOpLo, ShOpHi,
-                             ShAmt);
+    SDValue Hi =
+        DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
     SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
 
     SDValue Ops[2] = { Lo, Hi };
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 70e16eee346aa2..8c3a597ce0b085 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -51,8 +51,8 @@ enum NodeType : unsigned {
   CallSeqEnd,
   CallPrototype,
   ProxyReg,
-  FUN_SHFL_CLAMP,
-  FUN_SHFR_CLAMP,
+  FSHL_CLAMP,
+  FSHR_CLAMP,
   MUL_WIDE_SIGNED,
   MUL_WIDE_UNSIGNED,
   IMAD,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 510e4b81003119..0d9dd1b8ee70ac 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1665,167 +1665,6 @@ def BREV64 :
              "brev.b64 \t$dst, $a;",
              [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>;
 
-//
-// Rotate: Use ptx shf instruction if available.
-//
-
-// 32 bit r2 = rotl r1, n
-//    =>
-//        r2 = shf.l r1, r1, n
-def ROTL32imm_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt),
-            "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-def ROTL32reg_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-// 32 bit r2 = rotr r1, n
-//    =>
-//        r2 = shf.r r1, r1, n
-def ROTR32imm_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt),
-            "shf.r.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-def ROTR32reg_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "shf.r.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-// 32-bit software rotate by immediate.  $amt2 should equal 32 - $amt1.
-def ROT32imm_sw :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2),
-            "{{\n\t"
-            ".reg .b32 %lhs;\n\t"
-            ".reg .b32 %rhs;\n\t"
-            "shl.b32 \t%lhs, $src, $amt1;\n\t"
-            "shr.b32 \t%rhs, $src, $amt2;\n\t"
-            "add.u32 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            []>;
-
-def SUB_FRM_32 : SDNodeXForm<imm, [{
-  return CurDAG->getTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32);
-}]>;
-
-def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
-          (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
-      Requires<[noHWROT32]>;
-def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)),
-          (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>,
-      Requires<[noHWROT32]>;
-
-// 32-bit software rotate left by register.
-def ROTL32reg_sw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b32 %lhs;\n\t"
-            ".reg .b32 %rhs;\n\t"
-            ".reg .b32 %amt2;\n\t"
-            "shl.b32 \t%lhs, $src, $amt;\n\t"
-            "sub.s32 \t%amt2, 32, $amt;\n\t"
-            "shr.b32 \t%rhs, $src, %amt2;\n\t"
-            "add.u32 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[noHWROT32]>;
-
-// 32-bit software rotate right by register.
-def ROTR32reg_sw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b32 %lhs;\n\t"
-            ".reg .b32 %rhs;\n\t"
-            ".reg .b32 %amt2;\n\t"
-            "shr.b32 \t%lhs, $src, $amt;\n\t"
-            "sub.s32 \t%amt2, 32, $amt;\n\t"
-            "shl.b32 \t%rhs, $src, %amt2;\n\t"
-            "add.u32 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[noHWROT32]>;
-
-// 64-bit software rotate by immediate.  $amt2 should equal 64 - $amt1.
-def ROT64imm_sw :
-  NVPTXInst<(outs Int64Regs:$dst),
-            (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2),
-            "{{\n\t"
-            ".reg .b64 %lhs;\n\t"
-            ".reg .b64 %rhs;\n\t"
-            "shl.b64 \t%lhs, $src, $amt1;\n\t"
-            "shr.b64 \t%rhs, $src, $amt2;\n\t"
-            "add.u64 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            []>;
-
-def SUB_FRM_64 : SDNodeXForm<imm, [{
-    return CurDAG->getTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32);
-}]>;
-
-def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)),
-          (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>;
-def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)),
-          (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>;
-
-// 64-bit software rotate left by register.
-def ROTL64reg_sw :
-  NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b64 %lhs;\n\t"
-            ".reg .b64 %rhs;\n\t"
-            ".reg .u32 %amt2;\n\t"
-            "and.b32 \t%amt2, $amt, 63;\n\t"
-            "shl.b64 \t%lhs, $src, %amt2;\n\t"
-            "sub.u32 \t%amt2, 64, %a...
[truncated]

@llvmbot
Member

llvmbot commented Sep 25, 2024

@llvm/pr-subscribers-backend-nvptx

@Artem-B
Member

Artem-B commented Sep 25, 2024

In LLVM IR the high bits are the first arg and the low bits are the second arg, while in PTX this is reversed.

Oh. OH. Ouch!

Indeed: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shf

shf.l.mode.b32 d, a, b, c; // left shift
...
Operand b holds bits 63:32 and operand a holds bits 31:0 of the 64-bit source value.
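
Spelled out side by side (an illustrative sketch; %hi, %lo, and %amt are made-up names):

  ; LLVM IR: the high half is the FIRST value operand of a funnel shift, e.g. the
  ; low word of a 64-bit shift-right-parts is
  %lo_out = call i32 @llvm.fshr.i32(i32 %hi, i32 %lo, i32 %amt)
  ; PTX shf is the mirror image: operand a holds bits 31:0 and operand b holds
  ; bits 63:32, so lowering has to emit the operands swapped:
  ;   shf.r.clamp.b32  %lo_out, %lo, %hi, %amt;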

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/reland-rotate branch from 327b29b to 9ebcfbf on September 25, 2024 18:51
@Artem-B
Member

Artem-B commented Sep 25, 2024

Would it be possible for you to cherry-pick the original commits into the pull request branch, and then add a commit with the fixes, so it's easy to see what you had to change to fix the issue?

@AlexMaclean
Member Author

Would it be possible for you to cherry-pick the original commits into the pull request branch, and then add a commit with the fixes, so it's easy to see what you had to change to fix the issue?

The current version has a cherry-pick of the original change on main that came from landing the previous MR, followed by a commit with the fixes. Does that work? Or do you want me to cherry-pick each commit from the previous branch for the first MR?

@Artem-B
Member

Artem-B commented Sep 25, 2024

Never mind. The PR has exactly what I wanted. I was looking at GitHub's 'Compare' link on your last update, which only touched the tests. Commit 9ebcfbf is exactly what I wanted to see. Sorry about the noise.

@AlexMaclean
Member Author

@steelannelida Could you please confirm this version does not produce the numerical discrepancies observed previously? Thanks!

[NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handling (llvm#107655)

This change deprecates the following intrinsics which can be trivially
converted to llvm funnel-shift intrinsics:

- @llvm.nvvm.rotate.b32
- @llvm.nvvm.rotate.right.b64
- @llvm.nvvm.rotate.b64
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/reland-rotate branch from 9ebcfbf to d359c8f on September 26, 2024 17:46
@steelannelida

@steelannelida Could you please confirm this version does not produce the numerical discrepancies observed previously? Thanks!

Hi! Thank you for asking. I tried reproducing and I see no failures with this version.

@AlexMaclean AlexMaclean merged commit a131fbf into llvm:main Sep 27, 2024
6 of 8 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
…ift handling" (llvm#110025)

This change deprecates the following intrinsics which can be trivially
converted to llvm funnel-shift intrinsics:

- @llvm.nvvm.rotate.b32
- @llvm.nvvm.rotate.right.b64
- @llvm.nvvm.rotate.b64

This fixes a bug in the previous version (llvm#107655) which flipped the
order of the operands to the PTX funnel shift instruction. In LLVM IR
the high bits are the first arg and the low bits are the second arg,
while in PTX this is reversed.