Skip to content

Commit 9a4b687

Browse files
committed
[KnownBits] Make {s,u}{add,sub}_sat optimal
Changes are: 1) Make signed-overflow detection optimal. 2) For signed overflow, try to rule out one direction of overflow even if we can't rule out overflow entirely. 3) Intersect the add/sub result computed under the no-overflow assumption with the possible overflow clamping values, as opposed to using the add/sub result computed without that assumption.
1 parent 8673d0e commit 9a4b687

File tree

3 files changed

+78
-81
lines changed

3 files changed

+78
-81
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -610,28 +610,78 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
610610
const KnownBits &RHS) {
611611
// We don't see NSW even for sadd/ssub as we want to check if the result has
612612
// signed overflow.
613-
KnownBits Res =
614-
KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
615-
unsigned BitWidth = Res.getBitWidth();
616-
auto SignBitKnown = [&](const KnownBits &K) {
617-
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
618-
};
619-
std::optional<bool> Overflow;
613+
unsigned BitWidth = LHS.getBitWidth();
620614

615+
std::optional<bool> Overflow;
616+
// Even if we can't entirely rule out overflow, we may be able to rule out
617+
// overflow in one direction. This allows us to potentially keep some of the
618+
// add/sub bits. I.e., if we can't overflow in the positive direction we won't
619+
// clamp to INT_MAX so we can keep low 0s from the add/sub result.
620+
bool MayNegClamp = true;
621+
bool MayPosClamp = true;
621622
if (Signed) {
622-
// If we can actually detect overflow do so. Otherwise leave Overflow as
623-
// nullopt (we assume it may have happened).
624-
if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
623+
// Easy cases we can rule out any overflow.
624+
if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
625+
(LHS.isNonNegative() && RHS.isNegative())))
626+
Overflow = false;
627+
else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
628+
(LHS.isNonNegative() && RHS.isNonNegative()))))
629+
Overflow = false;
630+
else {
631+
// Check if we may overflow. If we can't rule out overflow then check if
632+
// we can rule out a direction at least.
633+
KnownBits UnsignedLHS = LHS;
634+
KnownBits UnsignedRHS = RHS;
635+
UnsignedLHS.One.clearSignBit();
636+
UnsignedLHS.Zero.setSignBit();
637+
UnsignedRHS.One.clearSignBit();
638+
UnsignedRHS.Zero.setSignBit();
639+
KnownBits Res =
640+
KnownBits::computeForAddSub(Add, /*NSW=*/false,
641+
/*NUW=*/false, UnsignedLHS, UnsignedRHS);
625642
if (Add) {
626-
// sadd.sat
627-
Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
628-
Res.isNonNegative() != LHS.isNonNegative());
643+
if (Res.isNegative()) {
644+
// Only overflow scenario is Pos + Pos.
645+
MayNegClamp = false;
646+
// Pos + Pos will overflow with extra signbit.
647+
if (LHS.isNonNegative() && RHS.isNonNegative())
648+
Overflow = true;
649+
} else if (Res.isNonNegative()) {
650+
// Only overflow scenario is Neg + Neg
651+
MayPosClamp = false;
652+
// Neg + Neg will overflow without extra signbit.
653+
if (LHS.isNegative() && RHS.isNegative())
654+
Overflow = true;
655+
}
656+
// We will never clamp to the opposite sign of N-bit result.
657+
if (LHS.isNegative() || RHS.isNegative())
658+
MayPosClamp = false;
659+
if (LHS.isNonNegative() || RHS.isNonNegative())
660+
MayNegClamp = false;
629661
} else {
630-
// ssub.sat
631-
Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
632-
Res.isNonNegative() != LHS.isNonNegative());
662+
if (Res.isNegative()) {
663+
// Only overflow scenario is Neg - Pos.
664+
MayPosClamp = false;
665+
// Neg - Pos will overflow with extra signbit.
666+
if (LHS.isNegative() && RHS.isNonNegative())
667+
Overflow = true;
668+
} else if (Res.isNonNegative()) {
669+
// Only overflow scenario is Pos - Neg.
670+
MayNegClamp = false;
671+
// Pos - Neg will overflow without extra signbit.
672+
if (LHS.isNonNegative() && RHS.isNegative())
673+
Overflow = true;
674+
}
675+
// We will never clamp to the opposite sign of N-bit result.
676+
if (LHS.isNegative() || RHS.isNonNegative())
677+
MayPosClamp = false;
678+
if (LHS.isNonNegative() || RHS.isNegative())
679+
MayNegClamp = false;
633680
}
634681
}
682+
// If we have ruled out all clamping, we will never overflow.
683+
if (!MayNegClamp && !MayPosClamp)
684+
Overflow = false;
635685
} else if (Add) {
636686
// uadd.sat
637687
bool Of;
@@ -656,52 +706,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
656706
}
657707
}
658708

659-
if (Signed) {
660-
if (Add) {
661-
if (LHS.isNonNegative() && RHS.isNonNegative()) {
662-
// Pos + Pos -> Pos
663-
Res.One.clearSignBit();
664-
Res.Zero.setSignBit();
665-
}
666-
if (LHS.isNegative() && RHS.isNegative()) {
667-
// Neg + Neg -> Neg
668-
Res.One.setSignBit();
669-
Res.Zero.clearSignBit();
670-
}
671-
} else {
672-
if (LHS.isNegative() && RHS.isNonNegative()) {
673-
// Neg - Pos -> Neg
674-
Res.One.setSignBit();
675-
Res.Zero.clearSignBit();
676-
} else if (LHS.isNonNegative() && RHS.isNegative()) {
677-
// Pos - Neg -> Pos
678-
Res.One.clearSignBit();
679-
Res.Zero.setSignBit();
680-
}
681-
}
682-
} else {
683-
// Add: Leading ones of either operand are preserved.
684-
// Sub: Leading zeros of LHS and leading ones of RHS are preserved
685-
// as leading zeros in the result.
686-
unsigned LeadingKnown;
687-
if (Add)
688-
LeadingKnown =
689-
std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
690-
else
691-
LeadingKnown =
692-
std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
693-
694-
// We select between the operation result and all-ones/zero
695-
// respectively, so we can preserve known ones/zeros.
696-
APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
697-
if (Add) {
698-
Res.One |= Mask;
699-
Res.Zero &= ~Mask;
700-
} else {
701-
Res.Zero |= Mask;
702-
Res.One &= ~Mask;
703-
}
704-
}
709+
KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
710+
/*NUW=*/!Signed, LHS, RHS);
705711

706712
if (Overflow) {
707713
// We know whether or not we overflowed.
@@ -714,7 +720,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
714720
APInt C;
715721
if (Signed) {
716722
// sadd.sat / ssub.sat
717-
assert(SignBitKnown(LHS) &&
723+
assert(!LHS.isSignUnknown() &&
718724
"We somehow know overflow without knowing input sign");
719725
C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
720726
: APInt::getSignedMaxValue(BitWidth);
@@ -735,8 +741,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
735741
if (Signed) {
736742
// sadd.sat/ssub.sat
737743
// We can keep our information about the sign bits.
738-
Res.Zero.clearLowBits(BitWidth - 1);
739-
Res.One.clearLowBits(BitWidth - 1);
744+
if (MayPosClamp)
745+
Res.Zero.clearLowBits(BitWidth - 1);
746+
if (MayNegClamp)
747+
Res.One.clearLowBits(BitWidth - 1);
740748
} else if (Add) {
741749
// uadd.sat
742750
// We need to clear all the known zeros as we can only use the leading ones.

llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
142142

143143
define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
144144
; CHECK-LABEL: @ssub_sat_fail_may_overflow(
145-
; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15
146-
; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15
147-
; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1
148-
; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2
149-
; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
150-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1
151-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0
152-
; CHECK-NEXT: ret i1 [[R]]
145+
; CHECK-NEXT: ret i1 false
153146
;
154147
%xx = and i8 %x, 15
155148
%yy = and i8 %y, 15

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
383383
"sadd_sat", KnownBits::sadd_sat,
384384
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
385385
return N1.sadd_sat(N2);
386-
},
387-
/*CheckOptimality=*/false);
386+
});
388387
testBinaryOpExhaustive(
389388
"uadd_sat", KnownBits::uadd_sat,
390389
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
391390
return N1.uadd_sat(N2);
392-
},
393-
/*CheckOptimality=*/false);
391+
});
394392
testBinaryOpExhaustive(
395393
"ssub_sat", KnownBits::ssub_sat,
396394
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
397395
return N1.ssub_sat(N2);
398-
},
399-
/*CheckOptimality=*/false);
396+
});
400397
testBinaryOpExhaustive(
401398
"usub_sat", KnownBits::usub_sat,
402399
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
403400
return N1.usub_sat(N2);
404-
},
405-
/*CheckOptimality=*/false);
401+
});
406402
testBinaryOpExhaustive(
407403
"shl",
408404
[](const KnownBits &Known1, const KnownBits &Known2) {

0 commit comments

Comments
 (0)