Skip to content

Commit d8a26ca

Browse files
authored
[KnownBits] Make abdu and abds optimal (#89081)
Fixes #84212
1 parent 8a21d59 commit d8a26ca

File tree

3 files changed, with 42 additions (+42) and 84 deletions (-84).

llvm/include/llvm/Support/KnownBits.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,7 @@ struct KnownBits {
394394
static KnownBits abdu(const KnownBits &LHS, const KnownBits &RHS);
395395

396396
/// Compute known bits for abds(LHS, RHS).
397-
static KnownBits abds(const KnownBits &LHS, const KnownBits &RHS);
397+
static KnownBits abds(KnownBits LHS, KnownBits RHS);
398398

399399
/// Compute known bits for shl(LHS, RHS).
400400
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.

llvm/lib/Support/KnownBits.cpp

Lines changed: 39 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -232,41 +232,53 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
232232
}
233233

234234
KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
235-
// abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
236-
KnownBits UMaxValue = umax(LHS, RHS);
237-
KnownBits UMinValue = umin(LHS, RHS);
238-
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
239-
/*NUW=*/true, UMaxValue, UMinValue);
235+
// If we know which argument is larger, return (sub LHS, RHS) or
236+
// (sub RHS, LHS) directly.
237+
if (LHS.getMinValue().uge(RHS.getMaxValue()))
238+
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
239+
RHS);
240+
if (RHS.getMinValue().uge(LHS.getMaxValue()))
241+
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
242+
LHS);
240243

241-
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
244+
// By construction, the subtraction in abdu never has unsigned overflow.
245+
// Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
242246
KnownBits Diff0 =
243-
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
247+
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
244248
KnownBits Diff1 =
245-
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
246-
KnownBits SubDiff = Diff0.intersectWith(Diff1);
247-
248-
KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
249-
assert(!KnownAbsDiff.hasConflict() && "Bad Output");
250-
return KnownAbsDiff;
249+
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
250+
return Diff0.intersectWith(Diff1);
251251
}
252252

253-
KnownBits KnownBits::abds(const KnownBits &LHS, const KnownBits &RHS) {
254-
// abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
255-
KnownBits SMaxValue = smax(LHS, RHS);
256-
KnownBits SMinValue = smin(LHS, RHS);
257-
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
258-
/*NUW=*/false, SMaxValue, SMinValue);
253+
KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
254+
// If we know which argument is larger, return (sub LHS, RHS) or
255+
// (sub RHS, LHS) directly.
256+
if (LHS.getSignedMinValue().sge(RHS.getSignedMaxValue()))
257+
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
258+
RHS);
259+
if (RHS.getSignedMinValue().sge(LHS.getSignedMaxValue()))
260+
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
261+
LHS);
262+
263+
// Shift both arguments from the signed range to the unsigned range, e.g. from
264+
// [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
265+
// abdu does.
266+
// Note that we can't just use "sub nsw" instead because abds has signed
267+
// inputs but an unsigned result, which makes the overflow conditions
268+
// different.
269+
unsigned SignBitPosition = LHS.getBitWidth() - 1;
270+
for (auto Arg : {&LHS, &RHS}) {
271+
bool Tmp = Arg->Zero[SignBitPosition];
272+
Arg->Zero.setBitVal(SignBitPosition, Arg->One[SignBitPosition]);
273+
Arg->One.setBitVal(SignBitPosition, Tmp);
274+
}
259275

260-
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
276+
// Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
261277
KnownBits Diff0 =
262-
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
278+
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
263279
KnownBits Diff1 =
264-
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
265-
KnownBits SubDiff = Diff0.intersectWith(Diff1);
266-
267-
KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
268-
assert(!KnownAbsDiff.hasConflict() && "Bad Output");
269-
return KnownAbsDiff;
280+
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
281+
return Diff0.intersectWith(Diff1);
270282
}
271283

272284
static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 2 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -294,58 +294,6 @@ TEST(KnownBitsTest, SignBitUnknown) {
294294
EXPECT_TRUE(Known.isSignUnknown());
295295
}
296296

297-
TEST(KnownBitsTest, ABDUSpecialCase) {
298-
// There are 2 implementations of abdu - both are currently needed to cover
299-
// extra cases.
300-
KnownBits LHS, RHS, Res;
301-
302-
// abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
303-
// Actual: false (Inputs = 1011, 101?, Computed = 000?, Exact = 000?)
304-
LHS.One = APInt(4, 0b1011);
305-
RHS.One = APInt(4, 0b1010);
306-
LHS.Zero = APInt(4, 0b0100);
307-
RHS.Zero = APInt(4, 0b0100);
308-
Res = KnownBits::abdu(LHS, RHS);
309-
EXPECT_EQ(0b0000ul, Res.One.getZExtValue());
310-
EXPECT_EQ(0b1110ul, Res.Zero.getZExtValue());
311-
312-
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
313-
// Actual: false (Inputs = ???1, 1000, Computed = ???1, Exact = 0??1)
314-
LHS.One = APInt(4, 0b0001);
315-
RHS.One = APInt(4, 0b1000);
316-
LHS.Zero = APInt(4, 0b0000);
317-
RHS.Zero = APInt(4, 0b0111);
318-
Res = KnownBits::abdu(LHS, RHS);
319-
EXPECT_EQ(0b0001ul, Res.One.getZExtValue());
320-
EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
321-
}
322-
323-
TEST(KnownBitsTest, ABDSSpecialCase) {
324-
// There are 2 implementations of abds - both are currently needed to cover
325-
// extra cases.
326-
KnownBits LHS, RHS, Res;
327-
328-
// abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
329-
// Actual: false (Inputs = 1011, 10??, Computed = ????, Exact = 00??)
330-
LHS.One = APInt(4, 0b1011);
331-
RHS.One = APInt(4, 0b1000);
332-
LHS.Zero = APInt(4, 0b0100);
333-
RHS.Zero = APInt(4, 0b0100);
334-
Res = KnownBits::abds(LHS, RHS);
335-
EXPECT_EQ(0, Res.One.getSExtValue());
336-
EXPECT_EQ(-4, Res.Zero.getSExtValue());
337-
338-
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
339-
// Actual: false (Inputs = ???1, 1000, Computed = ???1, Exact = 0??1)
340-
LHS.One = APInt(4, 0b0001);
341-
RHS.One = APInt(4, 0b1000);
342-
LHS.Zero = APInt(4, 0b0000);
343-
RHS.Zero = APInt(4, 0b0111);
344-
Res = KnownBits::abds(LHS, RHS);
345-
EXPECT_EQ(1, Res.One.getSExtValue());
346-
EXPECT_EQ(0, Res.Zero.getSExtValue());
347-
}
348-
349297
TEST(KnownBitsTest, BinaryExhaustive) {
350298
testBinaryOpExhaustive(
351299
[](const KnownBits &Known1, const KnownBits &Known2) {
@@ -366,10 +314,8 @@ TEST(KnownBitsTest, BinaryExhaustive) {
366314
testBinaryOpExhaustive(KnownBits::umin, APIntOps::umin);
367315
testBinaryOpExhaustive(KnownBits::smax, APIntOps::smax);
368316
testBinaryOpExhaustive(KnownBits::smin, APIntOps::smin);
369-
testBinaryOpExhaustive(KnownBits::abdu, APIntOps::abdu,
370-
checkCorrectnessOnlyBinary);
371-
testBinaryOpExhaustive(KnownBits::abds, APIntOps::abds,
372-
checkCorrectnessOnlyBinary);
317+
testBinaryOpExhaustive(KnownBits::abdu, APIntOps::abdu);
318+
testBinaryOpExhaustive(KnownBits::abds, APIntOps::abds);
373319
testBinaryOpExhaustive(
374320
[](const KnownBits &Known1, const KnownBits &Known2) {
375321
return KnownBits::udiv(Known1, Known2);

0 commit comments

Comments
 (0)