Skip to content

[KnownBits] Implement knownbits lshr/ashr with exact flag #84254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/include/llvm/Support/KnownBits.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,12 @@ struct KnownBits {
/// Compute known bits for lshr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
bool ShAmtNonZero = false);
bool ShAmtNonZero = false, bool Exact = false);

/// Compute known bits for ashr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
bool ShAmtNonZero = false);
bool ShAmtNonZero = false, bool Exact = false);

/// Determine if these known bits always give the same ICMP_EQ result.
static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::LShr: {
auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
bool ShAmtNonZero) {
return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
bool ShAmtNonZero) {
return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
Expand All @@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::AShr: {
auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
bool ShAmtNonZero) {
return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
bool ShAmtNonZero) {
return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known = KnownBits::lshr(Known, Known2);
Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
Op->getFlags().hasExact());

// Minimum shift high bits are known zero.
if (const APInt *ShMinAmt =
Expand All @@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known = KnownBits::ashr(Known, Known2);
Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
Op->getFlags().hasExact());
break;
case ISD::FSHL:
case ISD::FSHR:
Expand Down
28 changes: 26 additions & 2 deletions llvm/lib/Support/KnownBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}

KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
bool ShAmtNonZero) {
bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
Expand All @@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);

// If exact, bound MaxShiftAmount to first known 1 in LHS.
if (Exact) {
unsigned FirstOne = LHS.countMaxTrailingZeros();
if (FirstOne < MinShiftAmount) {
// Always poison. Return zero because we don't like returning conflict.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not your fault, but I really don't like that we work so hard everywhere to avoid returning conflict. We should embrace conflict!

Known.setAllZero();
return Known;
}
MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
}

unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
Expand All @@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}

KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
bool ShAmtNonZero) {
bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
Expand All @@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);

// If exact, bound MaxShiftAmount to first known 1 in LHS.
if (Exact) {
unsigned FirstOne = LHS.countMaxTrailingZeros();
if (FirstOne < MinShiftAmount) {
// Always poison. Return zero because we don't like returning conflict.
Known.setAllZero();
return Known;
}
MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
}

unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
Expand Down
20 changes: 20 additions & 0 deletions llvm/test/Analysis/ValueTracking/knownbits-shift.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=instcombine -S < %s | FileCheck %s

define i8 @simplify_lshr_with_exact(i8 %x) {
; CHECK-LABEL: @simplify_lshr_with_exact(
; CHECK-NEXT: ret i8 2
;
%shr = lshr exact i8 6, %x
%r = and i8 %shr, 2
ret i8 %r
}

define i8 @simplify_ashr_with_exact(i8 %x) {
; CHECK-LABEL: @simplify_ashr_with_exact(
; CHECK-NEXT: ret i8 2
;
%shr = ashr exact i8 -122, %x
%r = and i8 %shr, 2
ret i8 %r
}
26 changes: 26 additions & 0 deletions llvm/unittests/Support/KnownBitsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.lshr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
/*Exact=*/true);
},
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
if (N2.uge(N2.getBitWidth()))
return std::nullopt;
if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
return std::nullopt;
return N1.lshr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::ashr(Known1, Known2);
Expand All @@ -526,6 +539,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.ashr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
/*Exact=*/true);
},
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
if (N2.uge(N2.getBitWidth()))
return std::nullopt;
if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
return std::nullopt;
return N1.ashr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);

testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
Expand Down