Skip to content

Commit 877cb9a

Browse files
authored
[KnownBits] Make {s,u}{add,sub}_sat optimal (#113096)
Changes are: 1) Make signed-overflow detection optimal 2) For signed-overflow, try to rule out direction even if we can't totally rule out overflow. 3) Intersect add/sub assuming no overflow with possible overflow clamping values as opposed to add/sub without the assumption.
1 parent 15c7e9f commit 877cb9a

File tree

3 files changed

+82
-81
lines changed

3 files changed

+82
-81
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 77 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -610,28 +610,82 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
610610
const KnownBits &RHS) {
611611
// We don't see NSW even for sadd/ssub as we want to check if the result has
612612
// signed overflow.
613-
KnownBits Res =
614-
KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
615-
unsigned BitWidth = Res.getBitWidth();
616-
auto SignBitKnown = [&](const KnownBits &K) {
617-
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
618-
};
619-
std::optional<bool> Overflow;
613+
unsigned BitWidth = LHS.getBitWidth();
620614

615+
std::optional<bool> Overflow;
616+
// Even if we can't entirely rule out overflow, we may be able to rule out
617+
// overflow in one direction. This allows us to potentially keep some of the
618+
// add/sub bits. I.e if we can't overflow in the positive direction we won't
619+
// clamp to INT_MAX so we can keep low 0s from the add/sub result.
620+
bool MayNegClamp = true;
621+
bool MayPosClamp = true;
621622
if (Signed) {
622-
// If we can actually detect overflow do so. Otherwise leave Overflow as
623-
// nullopt (we assume it may have happened).
624-
if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
623+
// Easy cases we can rule out any overflow.
624+
if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
625+
(LHS.isNonNegative() && RHS.isNegative())))
626+
Overflow = false;
627+
else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
628+
(LHS.isNonNegative() && RHS.isNonNegative()))))
629+
Overflow = false;
630+
else {
631+
// Check if we may overflow. If we can't rule out overflow then check if
632+
// we can rule out a direction at least.
633+
KnownBits UnsignedLHS = LHS;
634+
KnownBits UnsignedRHS = RHS;
635+
// Get version of LHS/RHS with clearer signbit. This allows us to detect
636+
// how the addition/subtraction might overflow into the signbit. Then
637+
// using the actual known signbits of LHS/RHS, we can figure out which
638+
// overflows are/aren't possible.
639+
UnsignedLHS.One.clearSignBit();
640+
UnsignedLHS.Zero.setSignBit();
641+
UnsignedRHS.One.clearSignBit();
642+
UnsignedRHS.Zero.setSignBit();
643+
KnownBits Res =
644+
KnownBits::computeForAddSub(Add, /*NSW=*/false,
645+
/*NUW=*/false, UnsignedLHS, UnsignedRHS);
625646
if (Add) {
626-
// sadd.sat
627-
Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
628-
Res.isNonNegative() != LHS.isNonNegative());
647+
if (Res.isNegative()) {
648+
// Only overflow scenario is Pos + Pos.
649+
MayNegClamp = false;
650+
// Pos + Pos will overflow with extra signbit.
651+
if (LHS.isNonNegative() && RHS.isNonNegative())
652+
Overflow = true;
653+
} else if (Res.isNonNegative()) {
654+
// Only overflow scenario is Neg + Neg
655+
MayPosClamp = false;
656+
// Neg + Neg will overflow without extra signbit.
657+
if (LHS.isNegative() && RHS.isNegative())
658+
Overflow = true;
659+
}
660+
// We will never clamp to the opposite sign of N-bit result.
661+
if (LHS.isNegative() || RHS.isNegative())
662+
MayPosClamp = false;
663+
if (LHS.isNonNegative() || RHS.isNonNegative())
664+
MayNegClamp = false;
629665
} else {
630-
// ssub.sat
631-
Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
632-
Res.isNonNegative() != LHS.isNonNegative());
666+
if (Res.isNegative()) {
667+
// Only overflow scenario is Neg - Pos.
668+
MayPosClamp = false;
669+
// Neg - Pos will overflow with extra signbit.
670+
if (LHS.isNegative() && RHS.isNonNegative())
671+
Overflow = true;
672+
} else if (Res.isNonNegative()) {
673+
// Only overflow scenario is Pos - Neg.
674+
MayNegClamp = false;
675+
// Pos - Neg will overflow without extra signbit.
676+
if (LHS.isNonNegative() && RHS.isNegative())
677+
Overflow = true;
678+
}
679+
// We will never clamp to the opposite sign of N-bit result.
680+
if (LHS.isNegative() || RHS.isNonNegative())
681+
MayPosClamp = false;
682+
if (LHS.isNonNegative() || RHS.isNegative())
683+
MayNegClamp = false;
633684
}
634685
}
686+
// If we have ruled out all clamping, we will never overflow.
687+
if (!MayNegClamp && !MayPosClamp)
688+
Overflow = false;
635689
} else if (Add) {
636690
// uadd.sat
637691
bool Of;
@@ -656,52 +710,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
656710
}
657711
}
658712

659-
if (Signed) {
660-
if (Add) {
661-
if (LHS.isNonNegative() && RHS.isNonNegative()) {
662-
// Pos + Pos -> Pos
663-
Res.One.clearSignBit();
664-
Res.Zero.setSignBit();
665-
}
666-
if (LHS.isNegative() && RHS.isNegative()) {
667-
// Neg + Neg -> Neg
668-
Res.One.setSignBit();
669-
Res.Zero.clearSignBit();
670-
}
671-
} else {
672-
if (LHS.isNegative() && RHS.isNonNegative()) {
673-
// Neg - Pos -> Neg
674-
Res.One.setSignBit();
675-
Res.Zero.clearSignBit();
676-
} else if (LHS.isNonNegative() && RHS.isNegative()) {
677-
// Pos - Neg -> Pos
678-
Res.One.clearSignBit();
679-
Res.Zero.setSignBit();
680-
}
681-
}
682-
} else {
683-
// Add: Leading ones of either operand are preserved.
684-
// Sub: Leading zeros of LHS and leading ones of RHS are preserved
685-
// as leading zeros in the result.
686-
unsigned LeadingKnown;
687-
if (Add)
688-
LeadingKnown =
689-
std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
690-
else
691-
LeadingKnown =
692-
std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
693-
694-
// We select between the operation result and all-ones/zero
695-
// respectively, so we can preserve known ones/zeros.
696-
APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
697-
if (Add) {
698-
Res.One |= Mask;
699-
Res.Zero &= ~Mask;
700-
} else {
701-
Res.Zero |= Mask;
702-
Res.One &= ~Mask;
703-
}
704-
}
713+
KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
714+
/*NUW=*/!Signed, LHS, RHS);
705715

706716
if (Overflow) {
707717
// We know whether or not we overflowed.
@@ -714,7 +724,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
714724
APInt C;
715725
if (Signed) {
716726
// sadd.sat / ssub.sat
717-
assert(SignBitKnown(LHS) &&
727+
assert(!LHS.isSignUnknown() &&
718728
"We somehow know overflow without knowing input sign");
719729
C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
720730
: APInt::getSignedMaxValue(BitWidth);
@@ -735,8 +745,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
735745
if (Signed) {
736746
// sadd.sat/ssub.sat
737747
// We can keep our information about the sign bits.
738-
Res.Zero.clearLowBits(BitWidth - 1);
739-
Res.One.clearLowBits(BitWidth - 1);
748+
if (MayPosClamp)
749+
Res.Zero.clearLowBits(BitWidth - 1);
750+
if (MayNegClamp)
751+
Res.One.clearLowBits(BitWidth - 1);
740752
} else if (Add) {
741753
// uadd.sat
742754
// We need to clear all the known zeros as we can only use the leading ones.

llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
142142

143143
define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
144144
; CHECK-LABEL: @ssub_sat_fail_may_overflow(
145-
; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15
146-
; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15
147-
; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1
148-
; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2
149-
; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
150-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1
151-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0
152-
; CHECK-NEXT: ret i1 [[R]]
145+
; CHECK-NEXT: ret i1 false
153146
;
154147
%xx = and i8 %x, 15
155148
%yy = and i8 %y, 15

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
383383
"sadd_sat", KnownBits::sadd_sat,
384384
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
385385
return N1.sadd_sat(N2);
386-
},
387-
/*CheckOptimality=*/false);
386+
});
388387
testBinaryOpExhaustive(
389388
"uadd_sat", KnownBits::uadd_sat,
390389
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
391390
return N1.uadd_sat(N2);
392-
},
393-
/*CheckOptimality=*/false);
391+
});
394392
testBinaryOpExhaustive(
395393
"ssub_sat", KnownBits::ssub_sat,
396394
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
397395
return N1.ssub_sat(N2);
398-
},
399-
/*CheckOptimality=*/false);
396+
});
400397
testBinaryOpExhaustive(
401398
"usub_sat", KnownBits::usub_sat,
402399
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
403400
return N1.usub_sat(N2);
404-
},
405-
/*CheckOptimality=*/false);
401+
});
406402
testBinaryOpExhaustive(
407403
"shl",
408404
[](const KnownBits &Known1, const KnownBits &Known2) {

0 commit comments

Comments
 (0)