Skip to content

[SLP]Improve minbitwidth analysis for shifts. #84356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13991,9 +13991,11 @@ bool BoUpSLP::collectValuesToDemote(
if (MultiNodeScalars.contains(V))
return false;
uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
return true;
if (OrigBitWidth < BitWidth) {
APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
return true;
}
auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
unsigned BitWidth1 = OrigBitWidth - NumSignBits;
if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
Expand Down Expand Up @@ -14038,6 +14040,30 @@ bool BoUpSLP::collectValuesToDemote(
}
return true;
};
auto AttemptCheckBitwidth =
[&](function_ref<bool(unsigned, unsigned)> Checker, bool &NeedToExit) {
// Try all bitwidth < OrigBitWidth.
NeedToExit = false;
uint32_t OrigBitWidth = DL->getTypeSizeInBits(I->getType());
unsigned BestFailBitwidth = 0;
for (; BitWidth < OrigBitWidth; BitWidth *= 2) {
if (Checker(BitWidth, OrigBitWidth))
return true;
if (BestFailBitwidth == 0 && FinalAnalysis())
BestFailBitwidth = BitWidth;
}
if (BitWidth >= OrigBitWidth) {
if (BestFailBitwidth == 0) {
BitWidth = OrigBitWidth;
return false;
}
MaxDepthLevel = 1;
BitWidth = BestFailBitwidth;
NeedToExit = true;
return true;
}
return false;
};
bool NeedToExit = false;
switch (I->getOpcode()) {

Expand Down Expand Up @@ -14070,6 +14096,71 @@ bool BoUpSLP::collectValuesToDemote(
return false;
break;
}
case Instruction::Shl: {
// Several vectorized uses? Check if we can truncate it, otherwise - exit.
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
return false;
// If we are truncating the result of this SHL, and if it's a shift of an
// in-range amount, we can always perform a SHL in a smaller type.
if (!AttemptCheckBitwidth(
[&](unsigned BitWidth, unsigned) {
KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
return AmtKnownBits.getMaxValue().ult(BitWidth);
},
NeedToExit))
return false;
if (NeedToExit)
return true;
if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
return false;
break;
}
case Instruction::LShr: {
// Several vectorized uses? Check if we can truncate it, otherwise - exit.
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
return false;
// If this is a truncate of a logical shr, we can truncate it to a smaller
// lshr iff we know that the bits we would otherwise be shifting in are
// already zeros.
if (!AttemptCheckBitwidth(
[&](unsigned BitWidth, unsigned OrigBitWidth) {
KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
return AmtKnownBits.getMaxValue().ult(BitWidth) &&
MaskedValueIsZero(I->getOperand(0), ShiftedBits,
SimplifyQuery(*DL));
},
NeedToExit))
return false;
if (NeedToExit)
return true;
if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
return false;
break;
}
case Instruction::AShr: {
// Several vectorized uses? Check if we can truncate it, otherwise - exit.
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
return false;
// If this is a truncate of an arithmetic shr, we can truncate it to a
// smaller ashr iff we know that all the bits from the sign bit of the
// original type and the sign bit of the truncate type are similar.
if (!AttemptCheckBitwidth(
[&](unsigned BitWidth, unsigned OrigBitWidth) {
KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
unsigned ShiftedBits = OrigBitWidth - BitWidth;
return AmtKnownBits.getMaxValue().ult(BitWidth) &&
ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0,
AC, nullptr, DT);
},
NeedToExit))
return false;
if (NeedToExit)
return true;
if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
return false;
break;
}

// We can demote selects if we can demote their true and false values.
case Instruction::Select: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ define void @test() {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; CHECK-NEXT: [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
; CHECK-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
; CHECK-NEXT: [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
; CHECK-NEXT: store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
; CHECK-NEXT: [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
; CHECK-NEXT: ret void
;
entry:
Expand Down
25 changes: 13 additions & 12 deletions llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ define void @test() {
; CHECK-LABEL: @test(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
; CHECK-NEXT: [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
; CHECK-NEXT: [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
; CHECK-NEXT: [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
; CHECK-NEXT: [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
; CHECK-NEXT: [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
; CHECK-NEXT: [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
; CHECK-NEXT: [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
; CHECK-NEXT: [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
; CHECK-NEXT: [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
; CHECK-NEXT: [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
; CHECK-NEXT: [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
; CHECK-NEXT: [[TMP15:%.*]] = sext <4 x i16> [[TMP14]] to <4 x i32>
; CHECK-NEXT: store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
; CHECK-NEXT: ret void
;
Expand Down