Skip to content

[DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis #93182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 1, 2024

Conversation

RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented May 23, 2024

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc.

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element.

This addresses feedback on #92096

@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-backend-powerpc

Author: Simon Pilgrim (RKSimon)

Changes

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc.

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element.

This addresses feedback on #92096


Full diff: https://github.com/llvm/llvm-project/pull/93182.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+29-27)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+89-69)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+13-13)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/test/CodeGen/PowerPC/pr44183.ll (+3-4)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 96a6270690468..95afbeb5dd6ec 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has an uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V,
-                                           const APInt &DemandedElts) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth = 0) const;
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has an uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
 
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..b71b496c0aa84 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+                                  unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
     // Shifting more than the bitwidth is not valid.
     const APInt &ShAmt = SA->getAPIntValue();
     if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+      return ShAmt.getZExtValue();
+  } else {
+    KnownBits KnownAmt =
+        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+      return KnownAmt.getConstant().getZExtValue();
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MinShAmt && MinShAmt->ule(ShAmt))
+        continue;
+      MinShAmt = &ShAmt;
+    }
+    if (MinShAmt)
+      return MinShAmt->getZExtValue();
   }
-  return MinShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMinValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MaxShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MaxShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MaxShAmt && MaxShAmt->uge(ShAmt))
+        continue;
+      MaxShAmt = &ShAmt;
+    }
+    if (MaxShAmt)
+      return MaxShAmt->getZExtValue();
   }
-  return MaxShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMaxValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMaximumShiftAmountConstant(V, DemandedElts);
+  return getValidMaximumShiftAmount(V, DemandedElts, Depth);
 }
 
 /// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
 
     // Minimum shift low bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setLowBits(*ShMinAmt);
     break;
   }
   case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                             Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setHighBits(*ShMinAmt);
     break;
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
-    if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    if (std::optional<uint64_t> ShAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (const APInt *ShAmt =
-            getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> ShAmt =
+            getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (ShAmt->ult(Tmp))
-        return Tmp - ShAmt->getZExtValue();
+      if (*ShAmt < Tmp)
+        return Tmp - *ShAmt;
     }
     break;
   case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
   case ISD::SRA:
     // If the max shift amount isn't in range, then the shift can create poison.
-    return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+    return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
 
   case ISD::SCALAR_TO_VECTOR:
     // Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 37c72339fe295..dfcd5439b8a9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> MaxSA =
+            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       SDValue Op0 = Op.getOperand(0);
-      unsigned ShAmt = MaxSA->getZExtValue();
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1789,9 +1789,9 @@ bool TargetLowering::SimplifyDemandedBits(
         // TODO - support non-uniform vector amounts.
         if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
             InnerOp.hasOneUse()) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
-            unsigned InnerShAmt = SA2->getZExtValue();
+          if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+                  InnerOp, DemandedElts, Depth + 1)) {
+            unsigned InnerShAmt = *SA2;
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1918,9 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = MaxSA->getZExtValue();
+    if (std::optional<uint64_t> MaxSA =
+            TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -2598,11 +2598,11 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
-        const APInt *ShAmtC =
-            TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
-        if (!ShAmtC || ShAmtC->uge(BitWidth))
+        std::optional<uint64_t> ShAmtC =
+            TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+        if (!ShAmtC || *ShAmtC >= BitWidth)
           break;
-        uint64_t ShVal = ShAmtC->getZExtValue();
+        uint64_t ShVal = *ShAmtC;
 
         APInt HighBits =
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 215cbc308e43d..fd99f0e345d14 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20490,7 +20490,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+    if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    mflr r0
 ; CHECK-NEXT:    std r30, -16(r1) # 8-byte Folded Spill
 ; CHECK-NEXT:    stdu r1, -48(r1)
-; CHECK-NEXT:    std r0, 64(r1)
 ; CHECK-NEXT:    mr r30, r3
-; CHECK-NEXT:    ld r3, 8(r3)
+; CHECK-NEXT:    std r0, 64(r1)
+; CHECK-NEXT:    lwz r3, 8(r3)
 ; CHECK-NEXT:    lwz r4, 36(r30)
-; CHECK-NEXT:    rldicl r3, r3, 60, 4
+; CHECK-NEXT:    rlwinm r3, r3, 27, 0, 0
 ; CHECK-NEXT:    clrlwi r4, r4, 31
-; CHECK-NEXT:    slwi r3, r3, 31
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop

@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Simon Pilgrim (RKSimon)

Changes

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc.

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element.

This addresses feedback on #92096


Full diff: https://github.com/llvm/llvm-project/pull/93182.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+29-27)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+89-69)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+13-13)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/test/CodeGen/PowerPC/pr44183.ll (+3-4)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 96a6270690468..95afbeb5dd6ec 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has an uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V,
-                                           const APInt &DemandedElts) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth = 0) const;
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has an uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
 
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..b71b496c0aa84 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+                                  unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
     // Shifting more than the bitwidth is not valid.
     const APInt &ShAmt = SA->getAPIntValue();
     if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+      return ShAmt.getZExtValue();
+  } else {
+    KnownBits KnownAmt =
+        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+      return KnownAmt.getConstant().getZExtValue();
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MinShAmt && MinShAmt->ule(ShAmt))
+        continue;
+      MinShAmt = &ShAmt;
+    }
+    if (MinShAmt)
+      return MinShAmt->getZExtValue();
   }
-  return MinShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMinValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MaxShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MaxShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MaxShAmt && MaxShAmt->uge(ShAmt))
+        continue;
+      MaxShAmt = &ShAmt;
+    }
+    if (MaxShAmt)
+      return MaxShAmt->getZExtValue();
   }
-  return MaxShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMaxValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMaximumShiftAmountConstant(V, DemandedElts);
+  return getValidMaximumShiftAmount(V, DemandedElts, Depth);
 }
 
 /// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
 
     // Minimum shift low bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setLowBits(*ShMinAmt);
     break;
   }
   case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                             Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setHighBits(*ShMinAmt);
     break;
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
-    if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    if (std::optional<uint64_t> ShAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (const APInt *ShAmt =
-            getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> ShAmt =
+            getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (ShAmt->ult(Tmp))
-        return Tmp - ShAmt->getZExtValue();
+      if (*ShAmt < Tmp)
+        return Tmp - *ShAmt;
     }
     break;
   case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
   case ISD::SRA:
     // If the max shift amount isn't in range, then the shift can create poison.
-    return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+    return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
 
   case ISD::SCALAR_TO_VECTOR:
     // Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 37c72339fe295..dfcd5439b8a9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> MaxSA =
+            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       SDValue Op0 = Op.getOperand(0);
-      unsigned ShAmt = MaxSA->getZExtValue();
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1789,9 +1789,9 @@ bool TargetLowering::SimplifyDemandedBits(
         // TODO - support non-uniform vector amounts.
         if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
             InnerOp.hasOneUse()) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
-            unsigned InnerShAmt = SA2->getZExtValue();
+          if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+                  InnerOp, DemandedElts, Depth + 1)) {
+            unsigned InnerShAmt = *SA2;
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1918,9 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = MaxSA->getZExtValue();
+    if (std::optional<uint64_t> MaxSA =
+            TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -2598,11 +2598,11 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
-        const APInt *ShAmtC =
-            TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
-        if (!ShAmtC || ShAmtC->uge(BitWidth))
+        std::optional<uint64_t> ShAmtC =
+            TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+        if (!ShAmtC || *ShAmtC >= BitWidth)
           break;
-        uint64_t ShVal = ShAmtC->getZExtValue();
+        uint64_t ShVal = *ShAmtC;
 
         APInt HighBits =
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 215cbc308e43d..fd99f0e345d14 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20490,7 +20490,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+    if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    mflr r0
 ; CHECK-NEXT:    std r30, -16(r1) # 8-byte Folded Spill
 ; CHECK-NEXT:    stdu r1, -48(r1)
-; CHECK-NEXT:    std r0, 64(r1)
 ; CHECK-NEXT:    mr r30, r3
-; CHECK-NEXT:    ld r3, 8(r3)
+; CHECK-NEXT:    std r0, 64(r1)
+; CHECK-NEXT:    lwz r3, 8(r3)
 ; CHECK-NEXT:    lwz r4, 36(r30)
-; CHECK-NEXT:    rldicl r3, r3, 60, 4
+; CHECK-NEXT:    rlwinm r3, r3, 27, 0, 0
 ; CHECK-NEXT:    clrlwi r4, r4, 31
-; CHECK-NEXT:    slwi r3, r3, 31
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop

} else {
KnownBits KnownAmt =
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO if its constant only, Depth should start at MaxDepth - 2, otherwise think there is a lot of wasted compute.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem we're hitting is that in many cases its not just a bitcast of a constant, its often more complex than that and when trying to use value tracking during legalization we only get one chance, we can't wait for combines to simplify things.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, fair enough (maybe mention in a comment? This seems a bit non-intuitive).

@RKSimon RKSimon force-pushed the perf/valid-shiftamt branch from b7c6586 to 9ced5e9 Compare May 24, 2024 10:45
@RKSimon RKSimon force-pushed the perf/valid-shiftamt branch from 9ced5e9 to bbe3c19 Compare May 24, 2024 17:04
@RKSimon
Copy link
Collaborator Author

RKSimon commented May 27, 2024

any more comments?

@RKSimon RKSimon force-pushed the perf/valid-shiftamt branch 2 times, most recently from aaee413 to 57b8519 Compare May 29, 2024 20:59
@RKSimon
Copy link
Collaborator Author

RKSimon commented May 30, 2024

ping - anything else?

// Use computeKnownBits to find a hidden constant (usually type legalized).
// e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
KnownBits KnownAmt =
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think Depth + 1 makes sense here. This isn't a recursive function. If its used by a recursive function, the caller should be setting Depth + 1.

computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
if (KnownAmt.getMaxValue().ult(BitWidth))
return KnownAmt.getMaxValue().getZExtValue();
return std::nullopt;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there should a single helper for all of the getValid*ShiftAmount functions.
Think it could just be getValidShiftAmount impl but return a KnownBits (the constant case we can just use KnownBits::makeConstant()).

Then getValidShiftAmount returns if the result is constant + less than bitwidth, getValidMinimumShiftAmount returns the getMinValue(), and vice versa for maximum. Think that will save a lot of dup code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me take a look, I'm wondering if ConstantRange will help us here. KnownBits alone tends to make the min/max values 'fuzzy'.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a bit of a mixed bag. KnownBits impl is more complete but has to throw away range information that isn't tied to particular bits. ConstantRange has a less complete impl, but obv represents what we want here better.

@RKSimon RKSimon force-pushed the perf/valid-shiftamt branch from 57b8519 to 3879271 Compare May 30, 2024 18:49
// Use computeKnownBits to find a hidden constant/knownbits (usually type
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
KnownBits KnownAmt =
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still think this should be Depth and the callers should be in charge of managing Depth + 1 (the callers to getValidShiftAmount/getValidMinimumShiftAmount/getValidMaximumShiftAmount that is).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really mind either way - we do pass in the shift node and not the shift amount specifically so I can see both sides of it. This PR was just to pull out a minor tweak of #92096 and has instead turned into a monster :(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its my preference, but not a req.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guess my feeling is most of our recursive functions in DAGCombiner we put the +1 at the initial callsite, so it might be confusing otherwise. Also this prohibits a 0 depth use.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK - the depth = 0 case could be an issue in the matchTruncateWithPACK use in X86ISelLowering.cpp - this is currently the only call from a non-recursive function, but there will probably be more in the future.

@goldsteinn
Copy link
Contributor

LGTM (with preference for changing Depth handling).

Wait a day or so before pushing so others have a chance to take a look.

@RKSimon RKSimon force-pushed the perf/valid-shiftamt branch from 3879271 to ee0341d Compare May 30, 2024 21:39
if (!ShAmtC || ShAmtC->uge(BitWidth))
std::optional<uint64_t> ShAmtC =
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
if (!ShAmtC || *ShAmtC >= BitWidth)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why Depth + 2? (here and above)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are 'inner' shifts of shifts. The originals should have used Depth+1 but that was missed.

@RKSimon
Copy link
Collaborator Author

RKSimon commented May 31, 2024

Thanks everyone, unless anyone has any more concerns I'm intending to commit this over the weekend.

RKSimon added 2 commits June 1, 2024 16:12
…tAmount helpers to support KnownBits analysis

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc.

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element.

This addresses feedback on llvm#92096
RKSimon added 2 commits June 1, 2024 16:12
I forgot to account for ConstantRange expecting the upper bound to be non-inclusive
@RKSimon RKSimon force-pushed the perf/valid-shiftamt branch from 55fb0ca to 5af5cd1 Compare June 1, 2024 15:12
@RKSimon RKSimon merged commit 2b1dfd2 into llvm:main Jun 1, 2024
5 of 7 checks passed
@RKSimon RKSimon deleted the perf/valid-shiftamt branch June 1, 2024 15:48
@RKSimon
Copy link
Collaborator Author

RKSimon commented Jun 1, 2024

As usual when touching SimplifyDemandedBits et al this has managed to expose a bug somewhere else that only shows up in multi stage build bots - currently looking at this.

RKSimon added a commit that referenced this pull request Jun 1, 2024
…en not poison

Since #93182 we can now call computeKnownBits inside getValidMaximumShiftAmount to determine the bounds of the shift amount ensuring that it wasn't poison, meaning if we did freeze the ahift amount, isGuaranteedNotToBeUndefOrPoison would then fail as we can't call computeKnownBits through FREEZE for potentially poison values.

I'm still reducing a decent test case but wanted to get the buildbot fix ASAP.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants