Skip to content

Commit d81db0e

Browse files
committed
[KnownBits] Implement knownbits lshr/ashr with exact flag
The exact flag basically allows us to set an upper bound on shift amount when we have a known 1 in `LHS`. Typically we deduce exact using knownbits (on non-exact incoming shifts), so this is particularly impactful, but may be useful in some circumstances. Closes #84254
1 parent a9d913e commit d81db0e

File tree

3 files changed

+54
-8
lines changed

3 files changed

+54
-8
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
343343
}
344344

345345
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
346-
bool ShAmtNonZero, bool /*Exact*/) {
346+
bool ShAmtNonZero, bool Exact) {
347347
unsigned BitWidth = LHS.getBitWidth();
348348
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
349349
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
367367
// Find the common bits from all possible shifts.
368368
APInt MaxValue = RHS.getMaxValue();
369369
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
370+
371+
// If exact, bound MaxShiftAmount to first known 1 in LHS.
372+
if (Exact) {
373+
unsigned FirstOne = LHS.countMaxTrailingZeros();
374+
if (FirstOne < MinShiftAmount) {
375+
// Always poison. Return zero because we don't like returning conflict.
376+
Known.setAllZero();
377+
return Known;
378+
}
379+
MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
380+
}
381+
370382
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
371383
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
372384
Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
389401
}
390402

391403
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
392-
bool ShAmtNonZero, bool /*Exact*/) {
404+
bool ShAmtNonZero, bool Exact) {
393405
unsigned BitWidth = LHS.getBitWidth();
394406
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
395407
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
415427
// Find the common bits from all possible shifts.
416428
APInt MaxValue = RHS.getMaxValue();
417429
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
430+
431+
// If exact, bound MaxShiftAmount to first known 1 in LHS.
432+
if (Exact) {
433+
unsigned FirstOne = LHS.countMaxTrailingZeros();
434+
if (FirstOne < MinShiftAmount) {
435+
// Always poison. Return zero because we don't like returning conflict.
436+
Known.setAllZero();
437+
return Known;
438+
}
439+
MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
440+
}
441+
418442
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
419443
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
420444
Known.Zero.setAllBits();

llvm/test/Analysis/ValueTracking/knownbits-shift.ll

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33

44
define i8 @simplify_lshr_with_exact(i8 %x) {
55
; CHECK-LABEL: @simplify_lshr_with_exact(
6-
; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
7-
; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
8-
; CHECK-NEXT: ret i8 [[R]]
6+
; CHECK-NEXT: ret i8 2
97
;
108
%shr = lshr exact i8 6, %x
119
%r = and i8 %shr, 2
@@ -14,9 +12,7 @@ define i8 @simplify_lshr_with_exact(i8 %x) {
1412

1513
define i8 @simplify_ashr_with_exact(i8 %x) {
1614
; CHECK-LABEL: @simplify_ashr_with_exact(
17-
; CHECK-NEXT: [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
18-
; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
19-
; CHECK-NEXT: ret i8 [[R]]
15+
; CHECK-NEXT: ret i8 2
2016
;
2117
%shr = ashr exact i8 -122, %x
2218
%r = and i8 %shr, 2

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
516516
return N1.lshr(N2);
517517
},
518518
checkOptimalityBinary, /* RefinePoisonToZero */ true);
519+
testBinaryOpExhaustive(
520+
[](const KnownBits &Known1, const KnownBits &Known2) {
521+
return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
522+
/*Exact=*/true);
523+
},
524+
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
525+
if (N2.uge(N2.getBitWidth()))
526+
return std::nullopt;
527+
if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
528+
return std::nullopt;
529+
return N1.lshr(N2);
530+
},
531+
checkOptimalityBinary, /* RefinePoisonToZero */ true);
519532
testBinaryOpExhaustive(
520533
[](const KnownBits &Known1, const KnownBits &Known2) {
521534
return KnownBits::ashr(Known1, Known2);
@@ -526,6 +539,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
526539
return N1.ashr(N2);
527540
},
528541
checkOptimalityBinary, /* RefinePoisonToZero */ true);
542+
testBinaryOpExhaustive(
543+
[](const KnownBits &Known1, const KnownBits &Known2) {
544+
return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
545+
/*Exact=*/true);
546+
},
547+
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
548+
if (N2.uge(N2.getBitWidth()))
549+
return std::nullopt;
550+
if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
551+
return std::nullopt;
552+
return N1.ashr(N2);
553+
},
554+
checkOptimalityBinary, /* RefinePoisonToZero */ true);
529555

530556
testBinaryOpExhaustive(
531557
[](const KnownBits &Known1, const KnownBits &Known2) {

0 commit comments

Comments
 (0)