Skip to content

Commit b60cf84

Browse files
committed
[InstCombine] Add more cases for simplifying (icmp (and/or x, Mask), y)
This cleans up basically all the regressions assosiated from #84688 Proof of all new cases: https://alive2.llvm.org/ce/z/5yYWLb Closes #85445
1 parent 23047df commit b60cf84

File tree

2 files changed

+115
-80
lines changed

2 files changed

+115
-80
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 98 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4177,7 +4177,9 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
41774177
/// a check for a lossy truncation.
41784178
/// Folds:
41794179
/// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask
4180+
/// icmp SrcPred (x & ~Mask), ~Mask to icmp DstPred x, ~Mask
41804181
/// icmp eq/ne (x & ~Mask), 0 to icmp DstPred x, Mask
4182+
/// icmp eq/ne (~x | Mask), -1 to icmp DstPred x, Mask
41814183
/// Where Mask is some pattern that produces all-ones in low bits:
41824184
/// (-1 >> y)
41834185
/// ((-1 << y) >> y) <- non-canonical, has extra uses
@@ -4189,82 +4191,126 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
41894191
static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
41904192
Value *Op1, const SimplifyQuery &Q,
41914193
InstCombiner &IC) {
4192-
Value *X, *M;
4193-
bool NeedsNot = false;
4194-
4195-
auto CheckMask = [&](Value *V, bool Not) {
4196-
if (ICmpInst::isSigned(Pred) && !match(V, m_ImmConstant()))
4197-
return false;
4198-
return isMaskOrZero(V, Not, Q);
4199-
};
4200-
4201-
if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M))) &&
4202-
CheckMask(M, /*Not*/ false)) {
4203-
X = Op1;
4204-
} else if (match(Op1, m_Zero()) && ICmpInst::isEquality(Pred) &&
4205-
match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
4206-
NeedsNot = true;
4207-
if (IC.isFreeToInvert(X, X->hasOneUse()) && CheckMask(X, /*Not*/ true))
4208-
std::swap(X, M);
4209-
else if (!IC.isFreeToInvert(M, M->hasOneUse()) ||
4210-
!CheckMask(M, /*Not*/ true))
4211-
return nullptr;
4212-
} else {
4213-
return nullptr;
4214-
}
42154194

42164195
ICmpInst::Predicate DstPred;
42174196
switch (Pred) {
42184197
case ICmpInst::Predicate::ICMP_EQ:
4219-
// x & (-1 >> y) == x -> x u<= (-1 >> y)
4198+
// x & Mask == x
4199+
// x & ~Mask == 0
4200+
// ~x | Mask == -1
4201+
// -> x u<= Mask
4202+
// x & ~Mask == ~Mask
4203+
// -> ~Mask u<= x
42204204
DstPred = ICmpInst::Predicate::ICMP_ULE;
42214205
break;
42224206
case ICmpInst::Predicate::ICMP_NE:
4223-
// x & (-1 >> y) != x -> x u> (-1 >> y)
4207+
// x & Mask != x
4208+
// x & ~Mask != 0
4209+
// ~x | Mask != -1
4210+
// -> x u> Mask
4211+
// x & ~Mask != ~Mask
4212+
// -> ~Mask u> x
42244213
DstPred = ICmpInst::Predicate::ICMP_UGT;
42254214
break;
42264215
case ICmpInst::Predicate::ICMP_ULT:
4227-
// x & (-1 >> y) u< x -> x u> (-1 >> y)
4228-
// x u> x & (-1 >> y) -> x u> (-1 >> y)
4216+
// x & Mask u< x
4217+
// -> x u> Mask
4218+
// x & ~Mask u< ~Mask
4219+
// -> ~Mask u> x
42294220
DstPred = ICmpInst::Predicate::ICMP_UGT;
42304221
break;
42314222
case ICmpInst::Predicate::ICMP_UGE:
4232-
// x & (-1 >> y) u>= x -> x u<= (-1 >> y)
4233-
// x u<= x & (-1 >> y) -> x u<= (-1 >> y)
4223+
// x & Mask u>= x
4224+
// -> x u<= Mask
4225+
// x & ~Mask u>= ~Mask
4226+
// -> ~Mask u<= x
42344227
DstPred = ICmpInst::Predicate::ICMP_ULE;
42354228
break;
42364229
case ICmpInst::Predicate::ICMP_SLT:
4237-
// x & (-1 >> y) s< x -> x s> (-1 >> y)
4238-
// x s> x & (-1 >> y) -> x s> (-1 >> y)
4239-
if (!match(M, m_Constant())) // Can not do this fold with non-constant.
4240-
return nullptr;
4241-
if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
4242-
return nullptr;
4230+
// x & Mask s< x [iff Mask s>= 0]
4231+
// -> x s> Mask
4232+
// x & ~Mask s< ~Mask [iff ~Mask != 0]
4233+
// -> ~Mask s> x
42434234
DstPred = ICmpInst::Predicate::ICMP_SGT;
42444235
break;
42454236
case ICmpInst::Predicate::ICMP_SGE:
4246-
// x & (-1 >> y) s>= x -> x s<= (-1 >> y)
4247-
// x s<= x & (-1 >> y) -> x s<= (-1 >> y)
4248-
if (!match(M, m_Constant())) // Can not do this fold with non-constant.
4249-
return nullptr;
4250-
if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
4251-
return nullptr;
4237+
// x & Mask s>= x [iff Mask s>= 0]
4238+
// -> x s<= Mask
4239+
// x & ~Mask s>= ~Mask [iff ~Mask != 0]
4240+
// -> ~Mask s<= x
42524241
DstPred = ICmpInst::Predicate::ICMP_SLE;
42534242
break;
4254-
case ICmpInst::Predicate::ICMP_SGT:
4255-
case ICmpInst::Predicate::ICMP_SLE:
4256-
return nullptr;
4257-
case ICmpInst::Predicate::ICMP_UGT:
4258-
case ICmpInst::Predicate::ICMP_ULE:
4259-
llvm_unreachable("Instsimplify took care of commut. variant");
4260-
break;
42614243
default:
4262-
llvm_unreachable("All possible folds are handled.");
4244+
// We don't support sgt,sle
4245+
// ult/ugt are simplified to true/false respectively.
4246+
return nullptr;
42634247
}
42644248

4265-
// The mask value may be a vector constant that has undefined elements. But it
4266-
// may not be safe to propagate those undefs into the new compare, so replace
4267-
// those elements by copying an existing, defined, and safe scalar constant.
4249+
Value *X, *M;
4250+
// Put search code in lambda for early positive returns.
4251+
auto IsLowBitMask = [&]() {
4252+
if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M)))) {
4253+
X = Op1;
4254+
// Look for: x & Mask pred x
4255+
if (isMaskOrZero(M, /*Not=*/false, Q)) {
4256+
return !ICmpInst::isSigned(Pred) ||
4257+
(match(M, m_NonNegative()) || isKnownNonNegative(M, Q));
4258+
}
4259+
4260+
// Look for: x & ~Mask pred ~Mask
4261+
if (isMaskOrZero(X, /*Not=*/true, Q)) {
4262+
return !ICmpInst::isSigned(Pred) ||
4263+
isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
4264+
}
4265+
return false;
4266+
}
4267+
if (ICmpInst::isEquality(Pred) && match(Op1, m_AllOnes()) &&
4268+
match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(M))))) {
4269+
4270+
auto Check = [&]() {
4271+
// Look for: ~x | Mask == -1
4272+
if (isMaskOrZero(M, /*Not=*/false, Q)) {
4273+
if (Value *NotX =
4274+
IC.getFreelyInverted(X, X->hasOneUse(), &IC.Builder)) {
4275+
X = NotX;
4276+
return true;
4277+
}
4278+
}
4279+
return false;
4280+
};
4281+
if (Check())
4282+
return true;
4283+
std::swap(X, M);
4284+
return Check();
4285+
}
4286+
if (ICmpInst::isEquality(Pred) && match(Op1, m_Zero()) &&
4287+
match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
4288+
auto Check = [&]() {
4289+
// Look for: x & ~Mask == 0
4290+
if (isMaskOrZero(M, /*Not=*/true, Q)) {
4291+
if (Value *NotM =
4292+
IC.getFreelyInverted(M, M->hasOneUse(), &IC.Builder)) {
4293+
M = NotM;
4294+
return true;
4295+
}
4296+
}
4297+
return false;
4298+
};
4299+
if (Check())
4300+
return true;
4301+
std::swap(X, M);
4302+
return Check();
4303+
}
4304+
return false;
4305+
};
4306+
4307+
if (!IsLowBitMask())
4308+
return nullptr;
4309+
4310+
// The mask value may be a vector constant that has undefined elements. But
4311+
// it may not be safe to propagate those undefs into the new compare, so
4312+
// replace those elements by copying an existing, defined, and safe scalar
4313+
// constant.
42684314
Type *OpTy = M->getType();
42694315
auto *VecC = dyn_cast<Constant>(M);
42704316
auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
@@ -4280,8 +4326,6 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
42804326
M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
42814327
}
42824328

4283-
if (NeedsNot)
4284-
M = IC.Builder.CreateNot(M);
42854329
return IC.Builder.CreateICmp(DstPred, X, M);
42864330
}
42874331

llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,7 @@ define i1 @src_x_and_mask_slt(i8 %x, i8 %y, i1 %cond) {
680680
; CHECK-NEXT: [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
681681
; CHECK-NEXT: [[MASK_POS:%.*]] = icmp sgt i8 [[MASK]], -1
682682
; CHECK-NEXT: call void @llvm.assume(i1 [[MASK_POS]])
683-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[MASK]], [[X:%.*]]
684-
; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[AND]], [[X]]
683+
; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[MASK]], [[X:%.*]]
685684
; CHECK-NEXT: ret i1 [[R]]
686685
;
687686
%mask0 = lshr i8 -1, %y
@@ -699,8 +698,7 @@ define i1 @src_x_and_mask_sge(i8 %x, i8 %y, i1 %cond) {
699698
; CHECK-NEXT: [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
700699
; CHECK-NEXT: [[MASK_POS:%.*]] = icmp sgt i8 [[MASK]], -1
701700
; CHECK-NEXT: call void @llvm.assume(i1 [[MASK_POS]])
702-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[MASK]], [[X:%.*]]
703-
; CHECK-NEXT: [[R:%.*]] = icmp sge i8 [[AND]], [[X]]
701+
; CHECK-NEXT: [[R:%.*]] = icmp sge i8 [[MASK]], [[X:%.*]]
704702
; CHECK-NEXT: ret i1 [[R]]
705703
;
706704
%mask0 = lshr i8 -1, %y
@@ -745,9 +743,9 @@ define i1 @src_x_and_mask_sge_fail_maybe_neg(i8 %x, i8 %y, i1 %cond) {
745743
define i1 @src_x_and_nmask_eq(i8 %x, i8 %y, i1 %cond) {
746744
; CHECK-LABEL: @src_x_and_nmask_eq(
747745
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
748-
; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
749-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
750-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[NOT_MASK]], [[AND]]
746+
; CHECK-NEXT: [[R1:%.*]] = icmp ule i8 [[NOT_MASK0]], [[X:%.*]]
747+
; CHECK-NEXT: [[NOT_COND:%.*]] = xor i1 [[COND:%.*]], true
748+
; CHECK-NEXT: [[R:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[R1]]
751749
; CHECK-NEXT: ret i1 [[R]]
752750
;
753751
%not_mask0 = shl i8 -1, %y
@@ -760,9 +758,8 @@ define i1 @src_x_and_nmask_eq(i8 %x, i8 %y, i1 %cond) {
760758
define i1 @src_x_and_nmask_ne(i8 %x, i8 %y, i1 %cond) {
761759
; CHECK-LABEL: @src_x_and_nmask_ne(
762760
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
763-
; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
764-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
765-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[AND]], [[NOT_MASK]]
761+
; CHECK-NEXT: [[R1:%.*]] = icmp ugt i8 [[NOT_MASK0]], [[X:%.*]]
762+
; CHECK-NEXT: [[R:%.*]] = select i1 [[COND:%.*]], i1 [[R1]], i1 false
766763
; CHECK-NEXT: ret i1 [[R]]
767764
;
768765
%not_mask0 = shl i8 -1, %y
@@ -775,9 +772,8 @@ define i1 @src_x_and_nmask_ne(i8 %x, i8 %y, i1 %cond) {
775772
define i1 @src_x_and_nmask_ult(i8 %x, i8 %y, i1 %cond) {
776773
; CHECK-LABEL: @src_x_and_nmask_ult(
777774
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
778-
; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
779-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
780-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[AND]], [[NOT_MASK]]
775+
; CHECK-NEXT: [[R1:%.*]] = icmp ugt i8 [[NOT_MASK0]], [[X:%.*]]
776+
; CHECK-NEXT: [[R:%.*]] = select i1 [[COND:%.*]], i1 [[R1]], i1 false
781777
; CHECK-NEXT: ret i1 [[R]]
782778
;
783779
%not_mask0 = shl i8 -1, %y
@@ -790,9 +786,9 @@ define i1 @src_x_and_nmask_ult(i8 %x, i8 %y, i1 %cond) {
790786
define i1 @src_x_and_nmask_uge(i8 %x, i8 %y, i1 %cond) {
791787
; CHECK-LABEL: @src_x_and_nmask_uge(
792788
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
793-
; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
794-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
795-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], [[NOT_MASK]]
789+
; CHECK-NEXT: [[R1:%.*]] = icmp ule i8 [[NOT_MASK0]], [[X:%.*]]
790+
; CHECK-NEXT: [[NOT_COND:%.*]] = xor i1 [[COND:%.*]], true
791+
; CHECK-NEXT: [[R:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[R1]]
796792
; CHECK-NEXT: ret i1 [[R]]
797793
;
798794
%not_mask0 = shl i8 -1, %y
@@ -805,8 +801,7 @@ define i1 @src_x_and_nmask_uge(i8 %x, i8 %y, i1 %cond) {
805801
define i1 @src_x_and_nmask_slt(i8 %x, i8 %y) {
806802
; CHECK-LABEL: @src_x_and_nmask_slt(
807803
; CHECK-NEXT: [[NOT_MASK:%.*]] = shl nsw i8 -1, [[Y:%.*]]
808-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
809-
; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[AND]], [[NOT_MASK]]
804+
; CHECK-NEXT: [[R:%.*]] = icmp sgt i8 [[NOT_MASK]], [[X:%.*]]
810805
; CHECK-NEXT: ret i1 [[R]]
811806
;
812807
%not_mask = shl i8 -1, %y
@@ -818,8 +813,7 @@ define i1 @src_x_and_nmask_slt(i8 %x, i8 %y) {
818813
define i1 @src_x_and_nmask_sge(i8 %x, i8 %y) {
819814
; CHECK-LABEL: @src_x_and_nmask_sge(
820815
; CHECK-NEXT: [[NOT_MASK:%.*]] = shl nsw i8 -1, [[Y:%.*]]
821-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
822-
; CHECK-NEXT: [[R:%.*]] = icmp sge i8 [[AND]], [[NOT_MASK]]
816+
; CHECK-NEXT: [[R:%.*]] = icmp sle i8 [[NOT_MASK]], [[X:%.*]]
823817
; CHECK-NEXT: ret i1 [[R]]
824818
;
825819
%not_mask = shl i8 -1, %y
@@ -865,9 +859,8 @@ define i1 @src_x_or_mask_eq(i8 %x, i8 %y, i8 %z, i1 %c2, i1 %cond) {
865859
; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[X:%.*]], -124
866860
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[C2:%.*]], i8 [[TMP1]], i8 -46
867861
; CHECK-NEXT: [[TMP3:%.*]] = call i8 @llvm.umax.i8(i8 [[Z:%.*]], i8 [[TMP2]])
868-
; CHECK-NEXT: [[NX_CCC:%.*]] = sub i8 11, [[TMP3]]
869-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[NX_CCC]], [[MASK]]
870-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[OR]], -1
862+
; CHECK-NEXT: [[TMP4:%.*]] = add i8 [[TMP3]], -12
863+
; CHECK-NEXT: [[R:%.*]] = icmp ule i8 [[TMP4]], [[MASK]]
871864
; CHECK-NEXT: ret i1 [[R]]
872865
;
873866
%mask0 = lshr i8 -1, %y
@@ -886,9 +879,7 @@ define i1 @src_x_or_mask_ne(i8 %x, i8 %y, i1 %cond) {
886879
; CHECK-LABEL: @src_x_or_mask_ne(
887880
; CHECK-NEXT: [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
888881
; CHECK-NEXT: [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
889-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
890-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[MASK]], [[NX]]
891-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[OR]], -1
882+
; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[MASK]], [[X:%.*]]
892883
; CHECK-NEXT: ret i1 [[R]]
893884
;
894885
%mask0 = lshr i8 -1, %y

0 commit comments

Comments
 (0)