[SLP]Improve minbitwidth analysis for shifts. #84356
Conversation
Created using spr 1.3.5
@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Adds improved bitwidth analysis for shl/ashr/lshr instructions. The analysis is based on a similar version in InstCombiner.

Full diff: https://github.com/llvm/llvm-project/pull/84356.diff

3 Files Affected:
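For orientation, here is a minimal sketch — a hypothetical helper, not code from this PR — condensing the per-opcode legality checks that the diff below adds to collectValuesToDemote(). The function name canDemoteShift and the exact set of headers are assumptions; the ValueTracking calls mirror the ones used in the patch.

#include "llvm/Analysis/SimplifyQuery.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

// Returns true if the shift instruction I could be performed in a type of
// BitWidth bits without changing its result (ignoring profitability).
static bool canDemoteShift(Instruction *I, unsigned BitWidth,
                           const DataLayout &DL) {
  unsigned OrigBitWidth = DL.getTypeSizeInBits(I->getType());
  KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), DL);
  // In all three cases the shift amount must be known to fit the narrow type.
  if (AmtKnownBits.getMaxValue().uge(BitWidth))
    return false;
  switch (I->getOpcode()) {
  case Instruction::Shl:
    // shl can always be done in the narrower type once the amount fits.
    return true;
  case Instruction::LShr:
    // lshr is safe iff the bits that would be shifted in are already zero.
    return MaskedValueIsZero(I->getOperand(0),
                             APInt::getBitsSetFrom(OrigBitWidth, BitWidth),
                             SimplifyQuery(DL));
  case Instruction::AShr:
    // ashr is safe iff enough sign bits are known, so the sign bit of the
    // narrow type matches the original sign bit.
    return OrigBitWidth - BitWidth < ComputeNumSignBits(I->getOperand(0), DL);
  default:
    return false;
  }
}

In the patch itself these checks are written inline in each case of the switch and are followed by the recursion into both operands.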
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 1889bc09e85028..3364f34d0148cc 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10094,16 +10094,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
BitWidth = UserIt->second.second;
}
}
- auto CheckBitwidth = [&](const TreeEntry &TE) {
- Type *ScalarTy = TE.Scalars.front()->getType();
- if (!ScalarTy->isIntegerTy())
- return true;
- unsigned TEBitWidth = DL->getTypeStoreSize(ScalarTy);
- auto UserIt = MinBWs.find(TEUseEI.UserTE);
- if (UserIt != MinBWs.end())
- TEBitWidth = UserIt->second.second;
- return BitWidth == TEBitWidth;
- };
SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs;
DenseMap<Value *, int> UsedValuesEntry;
for (Value *V : VL) {
@@ -10138,8 +10128,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
continue;
}
- if (!CheckBitwidth(*TEPtr))
- continue;
// Check if the user node of the TE comes after user node of TEPtr,
// otherwise TEPtr depends on TE.
if ((TEInsertBlock != InsertPt->getParent() ||
@@ -10157,7 +10145,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
VTE = *It->getSecond().begin();
// Iterate through all vectorized nodes.
auto *MIt = find_if(It->getSecond(), [&](const TreeEntry *MTE) {
- return MTE->State == TreeEntry::Vectorize && CheckBitwidth(*MTE);
+ return MTE->State == TreeEntry::Vectorize;
});
if (MIt == It->getSecond().end())
continue;
@@ -10167,8 +10155,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
continue;
- if (!CheckBitwidth(*VTE))
- continue;
VToTEs.insert(VTE);
}
if (VToTEs.empty())
@@ -10216,6 +10202,45 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
return std::nullopt;
}
+ if (BitWidth > 0) {
+ // Check if the used TEs supposed to be resized and choose the best
+ // candidates.
+ unsigned NodesBitWidth = 0;
+ auto CheckBitwidth = [&](const TreeEntry &TE) {
+ unsigned TEBitWidth = BitWidth;
+ auto UserIt = MinBWs.find(TEUseEI.UserTE);
+ if (UserIt != MinBWs.end())
+ TEBitWidth = UserIt->second.second;
+ if (BitWidth <= TEBitWidth) {
+ if (NodesBitWidth == 0)
+ NodesBitWidth = TEBitWidth;
+ return NodesBitWidth == TEBitWidth;
+ }
+ return false;
+ };
+ for (auto [Idx, Set] : enumerate(UsedTEs)) {
+ DenseSet<const TreeEntry *> ForRemoval;
+ for (const TreeEntry *TE : Set) {
+ if (!CheckBitwidth(*TE))
+ ForRemoval.insert(TE);
+ }
+ // All elements must be removed - remove the whole container.
+ if (ForRemoval.size() == Set.size()) {
+ Set.clear();
+ continue;
+ }
+ for (const TreeEntry *TE : ForRemoval)
+ Set.erase(TE);
+ }
+ for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
+ if (It->empty()) {
+ UsedTEs.erase(It);
+ continue;
+ }
+ std::advance(It, 1);
+ }
+ }
+
unsigned VF = 0;
if (UsedTEs.size() == 1) {
// Keep the order to avoid non-determinism.
@@ -13946,6 +13971,63 @@ bool BoUpSLP::collectValuesToDemote(
MaxDepthLevel = std::max(Level1, Level2);
break;
}
+ case Instruction::Shl: {
+ // If we are truncating the result of this SHL, and if it's a shift of an
+ // inrange amount, we can always perform a SHL in a smaller type.
+ unsigned Level1, Level2;
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+ !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level1, IsProfitableToDemote) ||
+ !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level2, IsProfitableToDemote))
+ return false;
+ MaxDepthLevel = std::max(Level1, Level2);
+ break;
+ }
+ case Instruction::LShr: {
+ // If this is a truncate of a logical shr, we can truncate it to a smaller
+ // lshr iff we know that the bits we would otherwise be shifting in are
+ // already zeros.
+ uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+ unsigned Level1, Level2;
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+ if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+ !MaskedValueIsZero(I->getOperand(0), ShiftedBits, SimplifyQuery(*DL)) ||
+ !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level1, IsProfitableToDemote) ||
+ !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level2, IsProfitableToDemote))
+ return false;
+ MaxDepthLevel = std::max(Level1, Level2);
+ break;
+ }
+ case Instruction::AShr: {
+ // If this is a truncate of an arithmetic shr, we can truncate it to a
+ // smaller ashr iff we know that all the bits from the sign bit of the
+ // original type and the sign bit of the truncate type are similar.
+ uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+ unsigned Level1, Level2;
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ unsigned ShiftedBits = OrigBitWidth - BitWidth;
+ if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+ ShiftedBits >=
+ ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT) ||
+ !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level1, IsProfitableToDemote) ||
+ !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level2, IsProfitableToDemote))
+ return false;
+ MaxDepthLevel = std::max(Level1, Level2);
+ break;
+ }
// We can demote selects if we can demote their true and false values.
case Instruction::Select: {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
index 6f5d3d3785e0c8..6378f696b470d4 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
@@ -10,10 +10,8 @@ define void @test() {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; CHECK-NEXT: [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
-; CHECK-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
-; CHECK-NEXT: [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
-; CHECK-NEXT: [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
-; CHECK-NEXT: store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
+; CHECK-NEXT: [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
; CHECK-NEXT: ret void
;
entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index 86b1e1a801e32f..91ee4dba07009f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -5,18 +5,19 @@ define void @test() {
; CHECK-LABEL: @test(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
-; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
-; CHECK-NEXT: [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
-; CHECK-NEXT: [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT: [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT: [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT: [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT: [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
+; CHECK-NEXT: [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
+; CHECK-NEXT: [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
+; CHECK-NEXT: [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT: [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
+; CHECK-NEXT: [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT: [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT: [[TMP15:%.*]] = zext <4 x i16> [[TMP14]] to <4 x i32>
; CHECK-NEXT: store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
; CHECK-NEXT: ret void
;
if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
    ShiftedBits >=
        ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT))
  return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth);
If this section of each shift instruction type was pulled out into a small helper, then the 3 opcodes could share the rest of the code and we'd get rid of a lot of duplication - is it worth it?
Will try to reduce this code
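For context, a rough sketch of the kind of helper being suggested here — hypothetical code written against the diff above, not what eventually landed — in which the operand recursion is shared and each shift opcode supplies only its legality predicate (for example, the canDemoteShift sketched near the top of this page):

// Inside BoUpSLP::collectValuesToDemote(); I, BitWidth, ToDemote, etc. are
// the surrounding locals/parameters visible in the diff above.
auto TryDemoteShift = [&](function_ref<bool(Instruction *)> IsLegalToDemote) {
  unsigned Level1, Level2;
  if (!IsLegalToDemote(I) ||
      !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
                             BitWidth, ToDemote, DemotedConsts, Visited,
                             Level1, IsProfitableToDemote) ||
      !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
                             BitWidth, ToDemote, DemotedConsts, Visited,
                             Level2, IsProfitableToDemote))
    return false;
  MaxDepthLevel = std::max(Level1, Level2);
  return true;
};
// Each of the Shl/LShr/AShr cases would then reduce to a single call such as
//   if (!TryDemoteShift([&](Instruction *Shift) {
//         return canDemoteShift(Shift, BitWidth, *DL);
//       }))
//     return false;
//   break;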
Created using spr 1.3.5
✅ With the latest revision this PR passed the C/C++ code formatter.
Created using spr 1.3.5
Ping!
LGTM
Adds improved bitwidth analysis for shl/ashr/lshr instructions. The analysis is based on a similar version in InstCombiner.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: llvm#84356
Hi @alexey-bataev, I think we get a miscompile with this patch.
The problem occurs if both inputs to foo are 0xffffffffffffffff. Then the
in the slp-vectorizer output will result in
due to the "nsw" on the "shl", and that poison then turns the return value of the function into poison as well.
Thanks, will fix it ASAP
Ok, it is not an issue of this patch, but a long-standing issue in the vectorizer. I have a fix for this in one of the extra patches, will commit it separately with your reproducer.
Fixed in 26dd128
I verified that the fix solves the problem I saw. Thanks!
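As an aside for readers less familiar with the flag semantics, below is a minimal, self-contained C++ analogy for the hazard discussed in this exchange. It is hypothetical and is not the reproducer mentioned above: a value that fits a 64-bit type no longer fits a 32-bit one, just as a no-signed-wrap promise that held in the wide type can be broken — and thus yield poison — once the operation is narrowed.

#include <cstdint>
#include <iostream>

int main() {
  // 100000 * 100000 = 10,000,000,000: representable in 64 bits...
  int64_t Wide = int64_t{100000} * int64_t{100000};
  // ...but not in 32 bits, so the truncated value silently disagrees.
  int32_t Narrow = static_cast<int32_t>(Wide);
  std::cout << Wide << " vs " << Narrow << "\n";
  // An LLVM instruction that keeps its "nsw" flag after being demoted to a
  // narrower type makes exactly this kind of no-longer-true promise, which is
  // why such flags have to be dropped (or re-proven) when narrowing.
  return 0;
}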