Skip to content

Commit 5350e1b

Browse files
committed
[KnownBits] Implement accurate unsigned and signed max and min
Use the new implementation in ValueTracking, SelectionDAG and GlobalISel.

Differential Revision: https://reviews.llvm.org/D87034
1 parent 04ea680 commit 5350e1b

File tree

7 files changed

+169
-90
lines changed

7 files changed

+169
-90
lines changed

llvm/include/llvm/Support/KnownBits.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ struct KnownBits {
173173
One.extractBits(NumBits, BitPosition));
174174
}
175175

176+
/// Return KnownBits based on this, but updated given that the underlying
177+
/// value is known to be greater than or equal to Val.
178+
KnownBits makeGE(const APInt &Val) const;
179+
176180
/// Returns the minimum number of trailing zero bits.
177181
unsigned countMinTrailingZeros() const {
178182
return Zero.countTrailingOnes();
@@ -241,6 +245,18 @@ struct KnownBits {
241245
static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
242246
KnownBits RHS);
243247

248+
/// Compute known bits for umax(LHS, RHS).
249+
static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
250+
251+
/// Compute known bits for umin(LHS, RHS).
252+
static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
253+
254+
/// Compute known bits for smax(LHS, RHS).
255+
static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
256+
257+
/// Compute known bits for smin(LHS, RHS).
258+
static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
259+
244260
/// Insert the bits from a smaller known bits starting at bitPosition.
245261
void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
246262
Zero.insertBits(SubBits.Zero, BitPosition);

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,59 +1212,41 @@ static void computeKnownBitsFromOperator(const Operator *I,
12121212
if (SelectPatternResult::isMinOrMax(SPF)) {
12131213
computeKnownBits(RHS, Known, Depth + 1, Q);
12141214
computeKnownBits(LHS, Known2, Depth + 1, Q);
1215-
} else {
1216-
computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
1217-
computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
1215+
switch (SPF) {
1216+
default:
1217+
llvm_unreachable("Unhandled select pattern flavor!");
1218+
case SPF_SMAX:
1219+
Known = KnownBits::smax(Known, Known2);
1220+
break;
1221+
case SPF_SMIN:
1222+
Known = KnownBits::smin(Known, Known2);
1223+
break;
1224+
case SPF_UMAX:
1225+
Known = KnownBits::umax(Known, Known2);
1226+
break;
1227+
case SPF_UMIN:
1228+
Known = KnownBits::umin(Known, Known2);
1229+
break;
1230+
}
1231+
break;
12181232
}
12191233

1220-
unsigned MaxHighOnes = 0;
1221-
unsigned MaxHighZeros = 0;
1222-
if (SPF == SPF_SMAX) {
1223-
// If both sides are negative, the result is negative.
1224-
if (Known.isNegative() && Known2.isNegative())
1225-
// We can derive a lower bound on the result by taking the max of the
1226-
// leading one bits.
1227-
MaxHighOnes =
1228-
std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
1229-
// If either side is non-negative, the result is non-negative.
1230-
else if (Known.isNonNegative() || Known2.isNonNegative())
1231-
MaxHighZeros = 1;
1232-
} else if (SPF == SPF_SMIN) {
1233-
// If both sides are non-negative, the result is non-negative.
1234-
if (Known.isNonNegative() && Known2.isNonNegative())
1235-
// We can derive an upper bound on the result by taking the max of the
1236-
// leading zero bits.
1237-
MaxHighZeros = std::max(Known.countMinLeadingZeros(),
1238-
Known2.countMinLeadingZeros());
1239-
// If either side is negative, the result is negative.
1240-
else if (Known.isNegative() || Known2.isNegative())
1241-
MaxHighOnes = 1;
1242-
} else if (SPF == SPF_UMAX) {
1243-
// We can derive a lower bound on the result by taking the max of the
1244-
// leading one bits.
1245-
MaxHighOnes =
1246-
std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
1247-
} else if (SPF == SPF_UMIN) {
1248-
// We can derive an upper bound on the result by taking the max of the
1249-
// leading zero bits.
1250-
MaxHighZeros =
1251-
std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros());
1252-
} else if (SPF == SPF_ABS) {
1234+
computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
1235+
computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
1236+
1237+
// Only known if known in both the LHS and RHS.
1238+
Known.One &= Known2.One;
1239+
Known.Zero &= Known2.Zero;
1240+
1241+
if (SPF == SPF_ABS) {
12531242
// RHS from matchSelectPattern returns the negation part of abs pattern.
12541243
// If the negate has an NSW flag we can assume the sign bit of the result
12551244
// will be 0 because that makes abs(INT_MIN) undefined.
12561245
if (match(RHS, m_Neg(m_Specific(LHS))) &&
12571246
Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
1258-
MaxHighZeros = 1;
1247+
Known.Zero.setSignBit();
12591248
}
12601249

1261-
// Only known if known in both the LHS and RHS.
1262-
Known.One &= Known2.One;
1263-
Known.Zero &= Known2.Zero;
1264-
if (MaxHighOnes > 0)
1265-
Known.One.setHighBits(MaxHighOnes);
1266-
if (MaxHighZeros > 0)
1267-
Known.Zero.setHighBits(MaxHighZeros);
12681250
break;
12691251
}
12701252
case Instruction::FPTrunc:

llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,24 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
308308
Known, DemandedElts, Depth + 1);
309309
break;
310310
}
311-
case TargetOpcode::G_SMIN:
311+
case TargetOpcode::G_SMIN: {
312+
// TODO: Handle clamp pattern with number of sign bits
313+
KnownBits KnownRHS;
314+
computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
315+
Depth + 1);
316+
computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
317+
Depth + 1);
318+
Known = KnownBits::smin(Known, KnownRHS);
319+
break;
320+
}
312321
case TargetOpcode::G_SMAX: {
313322
// TODO: Handle clamp pattern with number of sign bits
314-
computeKnownBitsMin(MI.getOperand(1).getReg(), MI.getOperand(2).getReg(),
315-
Known, DemandedElts, Depth + 1);
323+
KnownBits KnownRHS;
324+
computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
325+
Depth + 1);
326+
computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
327+
Depth + 1);
328+
Known = KnownBits::smax(Known, KnownRHS);
316329
break;
317330
}
318331
case TargetOpcode::G_UMIN: {
@@ -321,13 +334,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
321334
DemandedElts, Depth + 1);
322335
computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
323336
DemandedElts, Depth + 1);
324-
325-
// UMIN - we know that the result will have the maximum of the
326-
// known zero leading bits of the inputs.
327-
unsigned LeadZero = Known.countMinLeadingZeros();
328-
LeadZero = std::max(LeadZero, KnownRHS.countMinLeadingZeros());
329-
Known &= KnownRHS;
330-
Known.Zero.setHighBits(LeadZero);
337+
Known = KnownBits::umin(Known, KnownRHS);
331338
break;
332339
}
333340
case TargetOpcode::G_UMAX: {
@@ -336,14 +343,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
336343
DemandedElts, Depth + 1);
337344
computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
338345
DemandedElts, Depth + 1);
339-
340-
// UMAX - we know that the result will have the maximum of the
341-
// known one leading bits of the inputs.
342-
unsigned LeadOne = Known.countMinLeadingOnes();
343-
LeadOne = std::max(LeadOne, KnownRHS.countMinLeadingOnes());
344-
Known.Zero &= KnownRHS.Zero;
345-
Known.One &= KnownRHS.One;
346-
Known.One.setHighBits(LeadOne);
346+
Known = KnownBits::umax(Known, KnownRHS);
347347
break;
348348
}
349349
case TargetOpcode::G_FCMP:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,29 +3390,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
33903390
case ISD::UMIN: {
33913391
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
33923392
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
3393-
3394-
// UMIN - we know that the result will have the maximum of the
3395-
// known zero leading bits of the inputs.
3396-
unsigned LeadZero = Known.countMinLeadingZeros();
3397-
LeadZero = std::max(LeadZero, Known2.countMinLeadingZeros());
3398-
3399-
Known.Zero &= Known2.Zero;
3400-
Known.One &= Known2.One;
3401-
Known.Zero.setHighBits(LeadZero);
3393+
Known = KnownBits::umin(Known, Known2);
34023394
break;
34033395
}
34043396
case ISD::UMAX: {
34053397
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
34063398
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
3407-
3408-
// UMAX - we know that the result will have the maximum of the
3409-
// known one leading bits of the inputs.
3410-
unsigned LeadOne = Known.countMinLeadingOnes();
3411-
LeadOne = std::max(LeadOne, Known2.countMinLeadingOnes());
3412-
3413-
Known.Zero &= Known2.Zero;
3414-
Known.One &= Known2.One;
3415-
Known.One.setHighBits(LeadOne);
3399+
Known = KnownBits::umax(Known, Known2);
34163400
break;
34173401
}
34183402
case ISD::SMIN:
@@ -3446,12 +3430,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
34463430
}
34473431
}
34483432

3449-
// Fallback - just get the shared known bits of the operands.
34503433
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
34513434
if (Known.isUnknown()) break; // Early-out
34523435
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
3453-
Known.Zero &= Known2.Zero;
3454-
Known.One &= Known2.One;
3436+
if (IsMax)
3437+
Known = KnownBits::smax(Known, Known2);
3438+
else
3439+
Known = KnownBits::smin(Known, Known2);
34553440
break;
34563441
}
34573442
case ISD::FrameIndex:

llvm/lib/Support/KnownBits.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,68 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
8383
return KnownOut;
8484
}
8585

86+
KnownBits KnownBits::makeGE(const APInt &Val) const {
87+
// Count the number of leading bit positions where the bit of our underlying
88+
// value is known to be less than or equal to the corresponding bit of Val.
89+
unsigned N = (Zero | Val).countLeadingOnes();
90+
91+
// For each of those bit positions, if Val has a 1 in that bit then our
92+
// underlying value must also have a 1.
93+
APInt MaskedVal(Val);
94+
MaskedVal.clearLowBits(getBitWidth() - N);
95+
return KnownBits(Zero, One | MaskedVal);
96+
}
97+
98+
KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
99+
// If we can prove that LHS >= RHS then use LHS as the result. Likewise for
100+
// RHS. Ideally our caller would already have spotted these cases and
101+
// optimized away the umax operation, but we handle them here for
102+
// completeness.
103+
if (LHS.getMinValue().uge(RHS.getMaxValue()))
104+
return LHS;
105+
if (RHS.getMinValue().uge(LHS.getMaxValue()))
106+
return RHS;
107+
108+
// If the result of the umax is LHS then it must be greater than or equal to
109+
// the minimum possible value of RHS. Likewise for RHS. Any known bits that
110+
// are common to these two values are also known in the result.
111+
KnownBits L = LHS.makeGE(RHS.getMinValue());
112+
KnownBits R = RHS.makeGE(LHS.getMinValue());
113+
return KnownBits(L.Zero & R.Zero, L.One & R.One);
114+
}
115+
116+
KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
117+
// Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
118+
auto Flip = [](KnownBits Val) { return KnownBits(Val.One, Val.Zero); };
119+
return Flip(umax(Flip(LHS), Flip(RHS)));
120+
}
121+
122+
KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
123+
// Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
124+
auto Flip = [](KnownBits Val) {
125+
unsigned SignBitPosition = Val.getBitWidth() - 1;
126+
APInt Zero = Val.Zero;
127+
APInt One = Val.One;
128+
Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
129+
One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
130+
return KnownBits(Zero, One);
131+
};
132+
return Flip(umax(Flip(LHS), Flip(RHS)));
133+
}
134+
135+
KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
136+
// Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
137+
auto Flip = [](KnownBits Val) {
138+
unsigned SignBitPosition = Val.getBitWidth() - 1;
139+
APInt Zero = Val.One;
140+
APInt One = Val.Zero;
141+
Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
142+
One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
143+
return KnownBits(Zero, One);
144+
};
145+
return Flip(umax(Flip(LHS), Flip(RHS)));
146+
}
147+
86148
KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
87149
// Result bit is 0 if either operand bit is 0.
88150
Zero |= RHS.Zero;

llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -719,9 +719,9 @@ TEST_F(AArch64GISelMITest, TestKnownBitsUMax) {
719719

720720
KnownBits KnownUmax = Info.getKnownBits(CopyUMax);
721721
EXPECT_EQ(64u, KnownUmax.getBitWidth());
722-
EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue());
722+
EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue());
723723
EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue());
724724

725-
EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue());
725+
EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue());
726726
EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue());
727727
}

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
103103
unsigned Bits = 4;
104104
ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
105105
ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
106-
KnownBits KnownAnd(Bits), KnownOr(Bits), KnownXor(Bits);
106+
KnownBits KnownAnd(Bits);
107107
KnownAnd.Zero.setAllBits();
108108
KnownAnd.One.setAllBits();
109-
KnownOr.Zero.setAllBits();
110-
KnownOr.One.setAllBits();
111-
KnownXor.Zero.setAllBits();
112-
KnownXor.One.setAllBits();
109+
KnownBits KnownOr(KnownAnd);
110+
KnownBits KnownXor(KnownAnd);
111+
KnownBits KnownUMax(KnownAnd);
112+
KnownBits KnownUMin(KnownAnd);
113+
KnownBits KnownSMax(KnownAnd);
114+
KnownBits KnownSMin(KnownAnd);
113115

114116
ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
115117
ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
@@ -126,6 +128,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
126128
Res = N1 ^ N2;
127129
KnownXor.One &= Res;
128130
KnownXor.Zero &= ~Res;
131+
132+
Res = APIntOps::umax(N1, N2);
133+
KnownUMax.One &= Res;
134+
KnownUMax.Zero &= ~Res;
135+
136+
Res = APIntOps::umin(N1, N2);
137+
KnownUMin.One &= Res;
138+
KnownUMin.Zero &= ~Res;
139+
140+
Res = APIntOps::smax(N1, N2);
141+
KnownSMax.One &= Res;
142+
KnownSMax.Zero &= ~Res;
143+
144+
Res = APIntOps::smin(N1, N2);
145+
KnownSMin.One &= Res;
146+
KnownSMin.Zero &= ~Res;
129147
});
130148
});
131149

@@ -140,6 +158,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
140158
KnownBits ComputedXor = Known1 ^ Known2;
141159
EXPECT_EQ(KnownXor.Zero, ComputedXor.Zero);
142160
EXPECT_EQ(KnownXor.One, ComputedXor.One);
161+
162+
KnownBits ComputedUMax = KnownBits::umax(Known1, Known2);
163+
EXPECT_EQ(KnownUMax.Zero, ComputedUMax.Zero);
164+
EXPECT_EQ(KnownUMax.One, ComputedUMax.One);
165+
166+
KnownBits ComputedUMin = KnownBits::umin(Known1, Known2);
167+
EXPECT_EQ(KnownUMin.Zero, ComputedUMin.Zero);
168+
EXPECT_EQ(KnownUMin.One, ComputedUMin.One);
169+
170+
KnownBits ComputedSMax = KnownBits::smax(Known1, Known2);
171+
EXPECT_EQ(KnownSMax.Zero, ComputedSMax.Zero);
172+
EXPECT_EQ(KnownSMax.One, ComputedSMax.One);
173+
174+
KnownBits ComputedSMin = KnownBits::smin(Known1, Known2);
175+
EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero);
176+
EXPECT_EQ(KnownSMin.One, ComputedSMin.One);
143177
});
144178
});
145179
}

0 commit comments

Comments
 (0)