
[SLP]Improve minbitwidth analysis for shifts. #84356


Conversation

alexey-bataev (Member)

Adds improved bitwidth analysis for shl/ashr/lshr instructions. The
analysis is based on a similar version in InstCombiner.

@llvmbot (Member) commented Mar 7, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Adds improved bitwidth analysis for shl/ashr/lshr instructions. The
analysis is based on a similar version in InstCombiner.
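
Condensed, the new legality rules are: an in-range shl can always be narrowed; an lshr can be narrowed iff the bits that would be shifted in are known zero; an ashr can be narrowed iff every bit being truncated away is a copy of the sign bit. Below is a hedged, self-contained sketch of just these checks; the helper name isShiftSafeToDemote and its free-function form are assumptions of this summary, while in the patch itself the checks are inlined into collectValuesToDemote, as the diff below shows.

  #include "llvm/Analysis/SimplifyQuery.h"
  #include "llvm/Analysis/ValueTracking.h"
  #include "llvm/IR/DataLayout.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/Support/KnownBits.h"
  using namespace llvm;

  // Can shift instruction I be performed in a type of width BitWidth?
  static bool isShiftSafeToDemote(Instruction *I, unsigned BitWidth,
                                  const DataLayout &DL) {
    // Common to all three opcodes: the shift amount must be known to be
    // in range for the narrow type.
    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), DL);
    if (AmtKnownBits.getMaxValue().uge(BitWidth))
      return false;
    unsigned OrigBitWidth = DL.getTypeSizeInBits(I->getType());
    switch (I->getOpcode()) {
    case Instruction::Shl:
      // An in-range shl can always be performed in the smaller type.
      return true;
    case Instruction::LShr: {
      // The bits shifted in from above the narrow width must be known zero.
      APInt HighBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
      return MaskedValueIsZero(I->getOperand(0), HighBits, SimplifyQuery(DL));
    }
    case Instruction::AShr:
      // Every bit truncated away must be a copy of the sign bit.
      return OrigBitWidth - BitWidth <
             ComputeNumSignBits(I->getOperand(0), DL);
    default:
      return false;
    }
  }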


Full diff: https://github.com/llvm/llvm-project/pull/84356.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+97-15)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll (+2-4)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll (+13-12)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 1889bc09e85028..3364f34d0148cc 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10094,16 +10094,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
         BitWidth = UserIt->second.second;
     }
   }
-  auto CheckBitwidth = [&](const TreeEntry &TE) {
-    Type *ScalarTy = TE.Scalars.front()->getType();
-    if (!ScalarTy->isIntegerTy())
-      return true;
-    unsigned TEBitWidth = DL->getTypeStoreSize(ScalarTy);
-    auto UserIt = MinBWs.find(TEUseEI.UserTE);
-    if (UserIt != MinBWs.end())
-      TEBitWidth = UserIt->second.second;
-    return BitWidth == TEBitWidth;
-  };
   SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs;
   DenseMap<Value *, int> UsedValuesEntry;
   for (Value *V : VL) {
@@ -10138,8 +10128,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
           continue;
       }
 
-      if (!CheckBitwidth(*TEPtr))
-        continue;
       // Check if the user node of the TE comes after user node of TEPtr,
       // otherwise TEPtr depends on TE.
       if ((TEInsertBlock != InsertPt->getParent() ||
@@ -10157,7 +10145,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
           VTE = *It->getSecond().begin();
           // Iterate through all vectorized nodes.
           auto *MIt = find_if(It->getSecond(), [&](const TreeEntry *MTE) {
-            return MTE->State == TreeEntry::Vectorize && CheckBitwidth(*MTE);
+            return MTE->State == TreeEntry::Vectorize;
           });
           if (MIt == It->getSecond().end())
             continue;
@@ -10167,8 +10155,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
       Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
       if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
         continue;
-      if (!CheckBitwidth(*VTE))
-        continue;
       VToTEs.insert(VTE);
     }
     if (VToTEs.empty())
@@ -10216,6 +10202,45 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
     return std::nullopt;
   }
 
+  if (BitWidth > 0) {
+    // Check if the used TEs supposed to be resized and choose the best
+    // candidates.
+    unsigned NodesBitWidth = 0;
+    auto CheckBitwidth = [&](const TreeEntry &TE) {
+      unsigned TEBitWidth = BitWidth;
+      auto UserIt = MinBWs.find(TEUseEI.UserTE);
+      if (UserIt != MinBWs.end())
+        TEBitWidth = UserIt->second.second;
+      if (BitWidth <= TEBitWidth) {
+        if (NodesBitWidth == 0)
+          NodesBitWidth = TEBitWidth;
+        return NodesBitWidth == TEBitWidth;
+      }
+      return false;
+    };
+    for (auto [Idx, Set] : enumerate(UsedTEs)) {
+      DenseSet<const TreeEntry *> ForRemoval;
+      for (const TreeEntry *TE : Set) {
+        if (!CheckBitwidth(*TE))
+          ForRemoval.insert(TE);
+      }
+      // All elements must be removed - remove the whole container.
+      if (ForRemoval.size() == Set.size()) {
+        Set.clear();
+        continue;
+      }
+      for (const TreeEntry *TE : ForRemoval)
+        Set.erase(TE);
+    }
+    for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
+      if (It->empty()) {
+        UsedTEs.erase(It);
+        continue;
+      }
+      std::advance(It, 1);
+    }
+  }
+
   unsigned VF = 0;
   if (UsedTEs.size() == 1) {
     // Keep the order to avoid non-determinism.
@@ -13946,6 +13971,63 @@ bool BoUpSLP::collectValuesToDemote(
     MaxDepthLevel = std::max(Level1, Level2);
     break;
   }
+  case Instruction::Shl: {
+    // If we are truncating the result of this SHL, and if it's a shift of an
+    // inrange amount, we can always perform a SHL in a smaller type.
+    unsigned Level1, Level2;
+    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level1, IsProfitableToDemote) ||
+        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level2, IsProfitableToDemote))
+      return false;
+    MaxDepthLevel = std::max(Level1, Level2);
+    break;
+  }
+  case Instruction::LShr: {
+    // If this is a truncate of a logical shr, we can truncate it to a smaller
+    // lshr iff we know that the bits we would otherwise be shifting in are
+    // already zeros.
+    uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+    unsigned Level1, Level2;
+    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+    APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+        !MaskedValueIsZero(I->getOperand(0), ShiftedBits, SimplifyQuery(*DL)) ||
+        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level1, IsProfitableToDemote) ||
+        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level2, IsProfitableToDemote))
+      return false;
+    MaxDepthLevel = std::max(Level1, Level2);
+    break;
+  }
+  case Instruction::AShr: {
+    // If this is a truncate of an arithmetic shr, we can truncate it to a
+    // smaller ashr iff we know that all the bits from the sign bit of the
+    // original type and the sign bit of the truncate type are similar.
+    uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+    unsigned Level1, Level2;
+    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+    unsigned ShiftedBits = OrigBitWidth - BitWidth;
+    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+        ShiftedBits >=
+            ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT) ||
+        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level1, IsProfitableToDemote) ||
+        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level2, IsProfitableToDemote))
+      return false;
+    MaxDepthLevel = std::max(Level1, Level2);
+    break;
+  }
 
   // We can demote selects if we can demote their true and false values.
   case Instruction::Select: {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
index 6f5d3d3785e0c8..6378f696b470d4 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
@@ -10,10 +10,8 @@ define void @test() {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
-; CHECK-NEXT:    [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
-; CHECK-NEXT:    [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
+; CHECK-NEXT:    [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index 86b1e1a801e32f..91ee4dba07009f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -5,18 +5,19 @@ define void @test() {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT:    [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT:    [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
+; CHECK-NEXT:    [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
+; CHECK-NEXT:    [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
+; CHECK-NEXT:    [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <4 x i16> [[TMP14]] to <4 x i32>
 ; CHECK-NEXT:    store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
 ; CHECK-NEXT:    ret void
 ;

    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
        ShiftedBits >=
            ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT))
      return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth);
Collaborator

If this section of each shift instruction type was pulled out into a small helper, then the 3 opcodes could share the rest of the code and we'd get rid of a lot of duplication - is it worth it?

Member Author

Will try to reduce this code
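
For illustration, one shape the reduction could take is sketched below: a hedged fragment meant to slot into collectValuesToDemote, assembled from the patch's own code. The name TryDemoteShift and the function_ref parameter are hypothetical; this is not the refactor that actually landed.

  // Each opcode supplies only its legality predicate; the recursion into
  // both operands and the depth bookkeeping become shared code.
  auto TryDemoteShift = [&](function_ref<bool()> IsLegalToDemote) {
    unsigned Level1, Level2;
    if (!IsLegalToDemote() ||
        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
                               BitWidth, ToDemote, DemotedConsts, Visited,
                               Level1, IsProfitableToDemote) ||
        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
                               BitWidth, ToDemote, DemotedConsts, Visited,
                               Level2, IsProfitableToDemote))
      return false;
    MaxDepthLevel = std::max(Level1, Level2);
    return true;
  };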


github-actions bot commented Mar 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@alexey-bataev (Member Author)

Ping!

@RKSimon (Collaborator) left a comment


LGTM

alexey-bataev merged commit 6c1d445 into main on Mar 20, 2024
alexey-bataev deleted the users/alexey-bataev/spr/slpimprove-minbitwidth-analysis-for-shifts branch on March 20, 2024 at 13:07
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
Adds improved bitwidth analysis for shl/ashr/lshr instructions. The
analysis is based on a similar version in InstCombiner.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: llvm#84356
@mikaelholmen (Collaborator) commented Mar 26, 2024

Hi @alexey-bataev,

I think we get a miscompile with this patch. Reproduce with:

  opt bbi-93743_2.ll -mtriple=aarch64 -passes=slp-vectorizer -S -o - -slp-threshold=-100

The problem occurs if both inputs to foo are 0xffffffffffffffff.

Then the

  %5 = shufflevector <2 x i128> %3, <2 x i128> %4, <2 x i32> <i32 0, i32 3>

in the slp-vectorizer output will result in

<0x10000000000000000,poison>

due to the "nsw" on the "shl", and that poison then makes the return value of the function poison as well.

bbi-93743_2.ll.gz
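
For context on the failure mode: narrowing a shift can invalidate poison-generating flags that were valid at the original width. A hedged, standalone illustration using APInt, with made-up values rather than the reproducer's:

  #include "llvm/ADT/APInt.h"
  #include <cassert>
  using namespace llvm;

  int main() {
    bool Overflow = false;
    // 0x8000 << 1 fits in a signed 32-bit value, so "shl nsw i32" holds...
    APInt Wide(32, 0x8000);
    (void)Wide.sshl_ov(APInt(32, 1), Overflow);
    assert(!Overflow);
    // ...but the same shift overflows the signed 16-bit range, so keeping
    // "nsw" on a demoted "shl i16" would make the result poison.
    APInt Narrow(16, 0x8000);
    (void)Narrow.sshl_ov(APInt(16, 1), Overflow);
    assert(Overflow);
    return 0;
  }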

@alexey-bataev (Member Author)

Thanks, will fix it ASAP

@alexey-bataev (Member Author) commented Mar 26, 2024

Ok, it is not an issue with this patch, but a long-standing issue in the vectorizer. I have a fix for this in one of the extra patches; I will commit it separately with your reproducer.

@alexey-bataev (Member Author)
Fixed in 26dd128

@mikaelholmen (Collaborator)

I verified that the fix solves the problem I saw. Thanks!
