Skip to content

Commit d182569

Browse files
committed
[InstCombine] Improve mask detection in foldICmpWithLowBitMaskedVal
Make recursive matcher that is able to detect a lot more patterns. Proofs for all supported patterns: https://alive2.llvm.org/ce/z/fSQ3nZ Differential Revision: https://reviews.llvm.org/D159058
1 parent cbc7ab7 commit d182569

14 files changed

+193
-67
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,19 @@ inline api_pred_ty<is_negated_power2> m_NegatedPower2(const APInt *&V) {
564564
return V;
565565
}
566566

567+
struct is_negated_power2_or_zero {
568+
bool isValue(const APInt &C) { return !C || C.isNegatedPowerOf2(); }
569+
};
570+
/// Match a integer or vector negated power-of-2.
571+
/// For vectors, this includes constants with undefined elements.
572+
inline cst_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
573+
return cst_pred_ty<is_negated_power2_or_zero>();
574+
}
575+
inline api_pred_ty<is_negated_power2_or_zero>
576+
m_NegatedPower2OrZero(const APInt *&V) {
577+
return V;
578+
}
579+
567580
struct is_power2_or_zero {
568581
bool isValue(const APInt &C) { return !C || C.isPowerOf2(); }
569582
};
@@ -595,6 +608,18 @@ inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
595608
}
596609
inline api_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) { return V; }
597610

611+
struct is_lowbit_mask_or_zero {
612+
bool isValue(const APInt &C) { return !C || C.isMask(); }
613+
};
614+
/// Match an integer or vector with only the low bit(s) set.
615+
/// For vectors, this includes constants with undefined elements.
616+
inline cst_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
617+
return cst_pred_ty<is_lowbit_mask_or_zero>();
618+
}
619+
inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
620+
return V;
621+
}
622+
598623
struct icmp_pred_with_threshold {
599624
ICmpInst::Predicate Pred;
600625
const APInt *Thr;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 120 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4068,6 +4068,95 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
40684068
return nullptr;
40694069
}
40704070

4071+
// Returns whether V is a Mask ((X + 1) & X == 0) or ~Mask (-Pow2OrZero)
4072+
static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
4073+
unsigned Depth = 0) {
4074+
if (Not ? match(V, m_NegatedPower2OrZero()) : match(V, m_LowBitMaskOrZero()))
4075+
return true;
4076+
if (V->getType()->getScalarSizeInBits() == 1)
4077+
return true;
4078+
if (Depth++ >= MaxAnalysisRecursionDepth)
4079+
return false;
4080+
Value *X;
4081+
const Operator *I = dyn_cast<Operator>(V);
4082+
if (I == nullptr)
4083+
return false;
4084+
switch (I->getOpcode()) {
4085+
case Instruction::ZExt:
4086+
// ZExt(Mask) is a Mask.
4087+
return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4088+
case Instruction::SExt:
4089+
// SExt(Mask) is a Mask.
4090+
// SExt(~Mask) is a ~Mask.
4091+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4092+
case Instruction::And:
4093+
case Instruction::Or:
4094+
// Mask0 | Mask1 is a Mask.
4095+
// Mask0 & Mask1 is a Mask.
4096+
// ~Mask0 | ~Mask1 is a ~Mask.
4097+
// ~Mask0 & ~Mask1 is a ~Mask.
4098+
return isMaskOrZero(I->getOperand(1), Not, Q, Depth) &&
4099+
isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4100+
case Instruction::Xor:
4101+
if (match(V, m_Not(m_Value(X))))
4102+
return isMaskOrZero(X, !Not, Q, Depth);
4103+
4104+
// (X ^ (X - 1)) is a Mask
4105+
return !Not &&
4106+
match(V, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())));
4107+
case Instruction::Select:
4108+
// c ? Mask0 : Mask1 is a Mask.
4109+
return isMaskOrZero(I->getOperand(1), Not, Q, Depth) &&
4110+
isMaskOrZero(I->getOperand(2), Not, Q, Depth);
4111+
case Instruction::Shl:
4112+
// (~Mask) << X is a ~Mask.
4113+
return Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4114+
case Instruction::LShr:
4115+
// Mask >> X is a Mask.
4116+
return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4117+
case Instruction::AShr:
4118+
// Mask s>> X is a Mask.
4119+
// ~Mask s>> X is a ~Mask.
4120+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4121+
case Instruction::Add:
4122+
// Pow2 - 1 is a Mask.
4123+
if (!Not && match(I->getOperand(1), m_AllOnes()))
4124+
return isKnownToBeAPowerOfTwo(I->getOperand(0), Q.DL, /*OrZero*/ true,
4125+
Depth, Q.AC, Q.CxtI, Q.DT);
4126+
break;
4127+
case Instruction::Sub:
4128+
// -Pow2 is a ~Mask.
4129+
if (Not && match(I->getOperand(0), m_Zero()))
4130+
return isKnownToBeAPowerOfTwo(I->getOperand(1), Q.DL, /*OrZero*/ true,
4131+
Depth, Q.AC, Q.CxtI, Q.DT);
4132+
break;
4133+
case Instruction::Call: {
4134+
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
4135+
switch (II->getIntrinsicID()) {
4136+
// min/max(Mask0, Mask1) is a Mask.
4137+
// min/max(~Mask0, ~Mask1) is a ~Mask.
4138+
case Intrinsic::umax:
4139+
case Intrinsic::smax:
4140+
case Intrinsic::umin:
4141+
case Intrinsic::smin:
4142+
return isMaskOrZero(II->getArgOperand(1), Not, Q, Depth) &&
4143+
isMaskOrZero(II->getArgOperand(0), Not, Q, Depth);
4144+
4145+
// In the context of masks, bitreverse(Mask) == ~Mask
4146+
case Intrinsic::bitreverse:
4147+
return isMaskOrZero(II->getArgOperand(0), !Not, Q, Depth);
4148+
default:
4149+
break;
4150+
}
4151+
}
4152+
break;
4153+
}
4154+
default:
4155+
break;
4156+
}
4157+
return false;
4158+
}
4159+
40714160
/// Some comparisons can be simplified.
40724161
/// In this case, we are looking for comparisons that look like
40734162
/// a check for a lossy truncation.
@@ -4081,21 +4170,35 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
40814170
/// The Mask can be a constant, too.
40824171
/// For some predicates, the operands are commutative.
40834172
/// For others, x can only be on a specific side.
4084-
static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
4085-
InstCombiner::BuilderTy &Builder) {
4173+
static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, const SimplifyQuery &Q,
4174+
InstCombiner &IC) {
4175+
4176+
Value *X, *M;
4177+
ICmpInst::Predicate Pred = I.getPredicate();
40864178
ICmpInst::Predicate SrcPred;
4087-
Value *X, *M, *Y;
4088-
auto m_VariableMask = m_CombineOr(
4089-
m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())),
4090-
m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())),
4091-
m_CombineOr(m_LShr(m_AllOnes(), m_Value()),
4092-
m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y))));
4093-
auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask());
4094-
if (!match(&I, m_c_ICmp(SrcPred,
4095-
m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)),
4096-
m_Deferred(X))))
4097-
return nullptr;
4179+
bool NeedsNot = false;
40984180

4181+
auto CheckMask = [&](Value *V, bool Not) {
4182+
if (!ICmpInst::isSigned(Pred))
4183+
return isMaskOrZero(V, Not, Q);
4184+
return Not ? match(V, m_NegatedPower2OrZero())
4185+
: match(V, m_LowBitMaskOrZero());
4186+
};
4187+
4188+
auto TryMatch = [&](unsigned OpNo) {
4189+
SrcPred = Pred;
4190+
if (match(I.getOperand(OpNo),
4191+
m_c_And(m_Specific(I.getOperand(1 - OpNo)), m_Value(M)))) {
4192+
X = I.getOperand(1 - OpNo);
4193+
if (OpNo)
4194+
SrcPred = ICmpInst::getSwappedPredicate(Pred);
4195+
return CheckMask(M, /*Not*/ false);
4196+
}
4197+
return false;
4198+
};
4199+
4200+
if (!TryMatch(0) && !TryMatch(1))
4201+
return nullptr;
40994202
ICmpInst::Predicate DstPred;
41004203
switch (SrcPred) {
41014204
case ICmpInst::Predicate::ICMP_EQ:
@@ -4163,7 +4266,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
41634266
M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
41644267
}
41654268

4166-
return Builder.CreateICmp(DstPred, X, M);
4269+
if (NeedsNot)
4270+
M = IC.Builder.CreateNot(M);
4271+
return IC.Builder.CreateICmp(DstPred, X, M);
41674272
}
41684273

41694274
/// Some comparisons can be simplified.
@@ -5080,7 +5185,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
50805185
if (Value *V = foldMultiplicationOverflowCheck(I))
50815186
return replaceInstUsesWith(I, V);
50825187

5083-
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
5188+
if (Value *V = foldICmpWithLowBitMaskedVal(I, Q, *this))
50845189
return replaceInstUsesWith(I, V);
50855190

50865191
if (Instruction *R = foldICmpAndXX(I, Q, *this))

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-eq-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 4, i8 1>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ne-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X:%.*]], <i8 3, i8 0>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sge-to-icmp-sle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
5050

5151
define <2 x i1> @p2_vec_nonsplat_edgecase(<2 x i8> %x) {
5252
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
53-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X:%.*]], <i8 3, i8 0>
54-
; CHECK-NEXT: [[RET:%.*]] = icmp sge <2 x i8> [[TMP0]], [[X]]
53+
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 4, i8 1>
5554
; CHECK-NEXT: ret <2 x i1> [[RET]]
5655
;
5756
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sgt-to-icmp-sgt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ define <2 x i1> @p2_vec_nonsplat() {
6363
define <2 x i1> @p2_vec_nonsplat_edgecase() {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
6565
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
66-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X]], <i8 3, i8 0>
67-
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X]], [[TMP0]]
66+
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X]], <i8 3, i8 0>
6867
; CHECK-NEXT: ret <2 x i1> [[RET]]
6968
;
7069
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sle-to-icmp-sle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ define <2 x i1> @p2_vec_nonsplat() {
6363
define <2 x i1> @p2_vec_nonsplat_edgecase() {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
6565
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
66-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X]], <i8 3, i8 0>
67-
; CHECK-NEXT: [[RET:%.*]] = icmp sle <2 x i8> [[X]], [[TMP0]]
66+
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[X]], <i8 4, i8 1>
6867
; CHECK-NEXT: ret <2 x i1> [[RET]]
6968
;
7069
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-slt-to-icmp-sgt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
5050

5151
define <2 x i1> @p2_vec_nonsplat_edgecase(<2 x i8> %x) {
5252
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
53-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X:%.*]], <i8 3, i8 0>
54-
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[TMP0]], [[X]]
53+
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X:%.*]], <i8 3, i8 0>
5554
; CHECK-NEXT: ret <2 x i1> [[RET]]
5655
;
5756
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-uge-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 4, i8 1>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ugt-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ define <2 x i1> @p2_vec_nonsplat() {
7575
define <2 x i1> @p2_vec_nonsplat_edgecase0() {
7676
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
7777
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
78-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X]], <i8 -4, i8 -1>
79-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
78+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X]], <i8 3, i8 0>
8079
; CHECK-NEXT: ret <2 x i1> [[RET]]
8180
;
8281
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ule-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ define <2 x i1> @p2_vec_nonsplat() {
7575
define <2 x i1> @p2_vec_nonsplat_edgecase0() {
7676
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
7777
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
78-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X]], <i8 -4, i8 -1>
79-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
78+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X]], <i8 4, i8 1>
8079
; CHECK-NEXT: ret <2 x i1> [[RET]]
8180
;
8281
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ult-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X:%.*]], <i8 3, i8 0>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

0 commit comments

Comments
 (0)