Skip to content

Commit 2a5aefd

Browse files
committed
[InstCombine] Improve mask detection in foldICmpWithLowBitMaskedVal
Add a recursive matcher that is able to detect many more patterns. Proofs for all supported patterns: https://alive2.llvm.org/ce/z/fSQ3nZ Differential Revision: https://reviews.llvm.org/D159058
1 parent b2a1e18 commit 2a5aefd

13 files changed

+185
-67
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,19 @@ inline api_pred_ty<is_negated_power2> m_NegatedPower2(const APInt *&V) {
564564
return V;
565565
}
566566

567+
struct is_negated_power2_or_zero {
568+
bool isValue(const APInt &C) { return !C || C.isNegatedPowerOf2(); }
569+
};
570+
/// Match an integer or vector that is zero or a negated power-of-2.
571+
/// For vectors, this includes constants with undefined elements.
572+
inline cst_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
573+
return cst_pred_ty<is_negated_power2_or_zero>();
574+
}
575+
inline api_pred_ty<is_negated_power2_or_zero>
576+
m_NegatedPower2OrZero(const APInt *&V) {
577+
return V;
578+
}
579+
567580
struct is_power2_or_zero {
568581
bool isValue(const APInt &C) { return !C || C.isPowerOf2(); }
569582
};
@@ -595,6 +608,18 @@ inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
595608
}
596609
inline api_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) { return V; }
597610

611+
struct is_lowbit_mask_or_zero {
612+
bool isValue(const APInt &C) { return !C || C.isMask(); }
613+
};
614+
/// Match an integer or vector that is zero or has only the low bit(s) set.
615+
/// For vectors, this includes constants with undefined elements.
616+
inline cst_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
617+
return cst_pred_ty<is_lowbit_mask_or_zero>();
618+
}
619+
inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
620+
return V;
621+
}
622+
598623
struct icmp_pred_with_threshold {
599624
ICmpInst::Predicate Pred;
600625
const APInt *Thr;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 134 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4068,6 +4068,109 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
40684068
return nullptr;
40694069
}
40704070

4071+
// Returns true if V is a Mask ((X + 1) & X == 0) or, when Not is set, a ~Mask (-Pow2OrZero).
4072+
static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
4073+
unsigned Depth = 0) {
4074+
if (Not ? match(V, m_NegatedPower2OrZero()) : match(V, m_LowBitMaskOrZero()))
4075+
return true;
4076+
if (V->getType()->getScalarSizeInBits() == 1)
4077+
return true;
4078+
if (Depth++ >= MaxAnalysisRecursionDepth)
4079+
return false;
4080+
Value *X;
4081+
if (match(V, m_Not(m_Value(X))))
4082+
return isMaskOrZero(X, !Not, Q, Depth);
4083+
const Operator *I = dyn_cast<Operator>(V);
4084+
if (I == nullptr)
4085+
return false;
4086+
switch (I->getOpcode()) {
4087+
case Instruction::ZExt:
4088+
// ZExt(Mask) is a Mask.
4089+
return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4090+
case Instruction::SExt:
4091+
// SExt(Mask) is a Mask.
4092+
// SExt(~Mask) is a ~Mask.
4093+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4094+
case Instruction::And:
4095+
case Instruction::Or:
4096+
// Mask0 | Mask1 is a Mask.
4097+
// Mask0 & Mask1 is a Mask.
4098+
// ~Mask0 | ~Mask1 is a ~Mask.
4099+
// ~Mask0 & ~Mask1 is a ~Mask.
4100+
return isMaskOrZero(I->getOperand(1), Not, Q, Depth) &&
4101+
isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4102+
case Instruction::Xor:
4103+
// (X ^ (X - 1)) is a Mask. NOTE(review): this only holds for !Not, but Not is not checked here.
4104+
return match(V, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())));
4105+
case Instruction::Select:
4106+
// c ? Mask0 : Mask1 is a Mask.
4107+
return isMaskOrZero(I->getOperand(1), Not, Q, Depth) &&
4108+
isMaskOrZero(I->getOperand(2), Not, Q, Depth);
4109+
case Instruction::Shl:
4110+
if (Not) {
4111+
// (-1 >> X) << X is a ~Mask.
4112+
if (match(I->getOperand(0),
4113+
m_Shr(m_AllOnes(), m_Specific(I->getOperand(1)))))
4114+
return true;
4115+
4116+
// (~Mask) << X is a ~Mask.
4117+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4118+
}
4119+
break;
4120+
case Instruction::LShr:
4121+
if (!Not) {
4122+
// (-1 << X) >> X is a Mask
4123+
if (match(I->getOperand(0),
4124+
m_Shl(m_AllOnes(), m_Specific(I->getOperand(1)))))
4125+
return true;
4126+
// Mask >> X is a Mask.
4127+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4128+
}
4129+
return false;
4130+
case Instruction::AShr:
4131+
// Mask s>> X is a Mask.
4132+
// ~Mask s>> X is a ~Mask.
4133+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4134+
case Instruction::Add:
4135+
// Pow2 - 1 is a Mask.
4136+
if (!Not && match(I->getOperand(1), m_AllOnes()))
4137+
return isKnownToBeAPowerOfTwo(I->getOperand(0), Q.DL, /*OrZero*/ true,
4138+
Depth, Q.AC, Q.CxtI, Q.DT);
4139+
break;
4140+
case Instruction::Sub:
4141+
// -Pow2 is a ~Mask.
4142+
if (Not && match(I->getOperand(0), m_Zero()))
4143+
return isKnownToBeAPowerOfTwo(I->getOperand(1), Q.DL, /*OrZero*/ true,
4144+
Depth, Q.AC, Q.CxtI, Q.DT);
4145+
break;
4146+
case Instruction::Invoke:
4147+
case Instruction::Call: {
4148+
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
4149+
switch (II->getIntrinsicID()) {
4150+
// min/max(Mask0, Mask1) is a Mask.
4151+
// min/max(~Mask0, ~Mask1) is a ~Mask.
4152+
case Intrinsic::umax:
4153+
case Intrinsic::smax:
4154+
case Intrinsic::umin:
4155+
case Intrinsic::smin:
4156+
return isMaskOrZero(II->getArgOperand(1), Not, Q, Depth) &&
4157+
isMaskOrZero(II->getArgOperand(0), Not, Q, Depth);
4158+
4159+
// In the context of masks, bitreverse(Mask) == ~Mask
4160+
case Intrinsic::bitreverse:
4161+
return isMaskOrZero(II->getArgOperand(0), !Not, Q, Depth);
4162+
default:
4163+
break;
4164+
}
4165+
}
4166+
break;
4167+
}
4168+
default:
4169+
break;
4170+
}
4171+
return false;
4172+
}
4173+
40714174
/// Some comparisons can be simplified.
40724175
/// In this case, we are looking for comparisons that look like
40734176
/// a check for a lossy truncation.
@@ -4081,21 +4184,35 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
40814184
/// The Mask can be a constant, too.
40824185
/// For some predicates, the operands are commutative.
40834186
/// For others, x can only be on a specific side.
4084-
static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
4085-
InstCombiner::BuilderTy &Builder) {
4187+
static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, const SimplifyQuery &Q,
4188+
InstCombiner &IC) {
4189+
4190+
Value *X, *M;
4191+
ICmpInst::Predicate Pred = I.getPredicate();
40864192
ICmpInst::Predicate SrcPred;
4087-
Value *X, *M, *Y;
4088-
auto m_VariableMask = m_CombineOr(
4089-
m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())),
4090-
m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())),
4091-
m_CombineOr(m_LShr(m_AllOnes(), m_Value()),
4092-
m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y))));
4093-
auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask());
4094-
if (!match(&I, m_c_ICmp(SrcPred,
4095-
m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)),
4096-
m_Deferred(X))))
4097-
return nullptr;
4193+
bool NeedsNot = false;
4194+
4195+
auto CheckMask = [&](Value *V, bool Not) {
4196+
if (!ICmpInst::isSigned(Pred))
4197+
return isMaskOrZero(V, Not, Q);
4198+
return Not ? match(V, m_NegatedPower2OrZero())
4199+
: match(V, m_LowBitMaskOrZero());
4200+
};
40984201

4202+
auto TryMatch = [&](unsigned OpNo) {
4203+
SrcPred = Pred;
4204+
if (match(I.getOperand(OpNo),
4205+
m_c_And(m_Specific(I.getOperand(1 - OpNo)), m_Value(M)))) {
4206+
X = I.getOperand(1 - OpNo);
4207+
if (OpNo)
4208+
SrcPred = ICmpInst::getSwappedPredicate(Pred);
4209+
return CheckMask(M, /*Not*/ false);
4210+
}
4211+
return false;
4212+
};
4213+
4214+
if (!TryMatch(0) && !TryMatch(1))
4215+
return nullptr;
40994216
ICmpInst::Predicate DstPred;
41004217
switch (SrcPred) {
41014218
case ICmpInst::Predicate::ICMP_EQ:
@@ -4163,7 +4280,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
41634280
M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
41644281
}
41654282

4166-
return Builder.CreateICmp(DstPred, X, M);
4283+
if (NeedsNot)
4284+
M = IC.Builder.CreateNot(M);
4285+
return IC.Builder.CreateICmp(DstPred, X, M);
41674286
}
41684287

41694288
/// Some comparisons can be simplified.
@@ -5080,7 +5199,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
50805199
if (Value *V = foldMultiplicationOverflowCheck(I))
50815200
return replaceInstUsesWith(I, V);
50825201

5083-
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
5202+
if (Value *V = foldICmpWithLowBitMaskedVal(I, Q, *this))
50845203
return replaceInstUsesWith(I, V);
50855204

50865205
if (Instruction *R = foldICmpAndXX(I, Q, *this))

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-eq-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 4, i8 1>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ne-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X:%.*]], <i8 3, i8 0>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sge-to-icmp-sle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
5050

5151
define <2 x i1> @p2_vec_nonsplat_edgecase(<2 x i8> %x) {
5252
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
53-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X:%.*]], <i8 3, i8 0>
54-
; CHECK-NEXT: [[RET:%.*]] = icmp sge <2 x i8> [[TMP0]], [[X]]
53+
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 4, i8 1>
5554
; CHECK-NEXT: ret <2 x i1> [[RET]]
5655
;
5756
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sgt-to-icmp-sgt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ define <2 x i1> @p2_vec_nonsplat() {
6363
define <2 x i1> @p2_vec_nonsplat_edgecase() {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
6565
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
66-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X]], <i8 3, i8 0>
67-
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X]], [[TMP0]]
66+
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X]], <i8 3, i8 0>
6867
; CHECK-NEXT: ret <2 x i1> [[RET]]
6968
;
7069
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sle-to-icmp-sle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ define <2 x i1> @p2_vec_nonsplat() {
6363
define <2 x i1> @p2_vec_nonsplat_edgecase() {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
6565
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
66-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X]], <i8 3, i8 0>
67-
; CHECK-NEXT: [[RET:%.*]] = icmp sle <2 x i8> [[X]], [[TMP0]]
66+
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[X]], <i8 4, i8 1>
6867
; CHECK-NEXT: ret <2 x i1> [[RET]]
6968
;
7069
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-slt-to-icmp-sgt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
5050

5151
define <2 x i1> @p2_vec_nonsplat_edgecase(<2 x i8> %x) {
5252
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
53-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X:%.*]], <i8 3, i8 0>
54-
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[TMP0]], [[X]]
53+
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X:%.*]], <i8 3, i8 0>
5554
; CHECK-NEXT: ret <2 x i1> [[RET]]
5655
;
5756
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-uge-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 4, i8 1>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ugt-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ define <2 x i1> @p2_vec_nonsplat() {
7575
define <2 x i1> @p2_vec_nonsplat_edgecase0() {
7676
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
7777
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
78-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X]], <i8 -4, i8 -1>
79-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
78+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X]], <i8 3, i8 0>
8079
; CHECK-NEXT: ret <2 x i1> [[RET]]
8180
;
8281
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ule-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ define <2 x i1> @p2_vec_nonsplat() {
7575
define <2 x i1> @p2_vec_nonsplat_edgecase0() {
7676
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
7777
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
78-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X]], <i8 -4, i8 -1>
79-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
78+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X]], <i8 4, i8 1>
8079
; CHECK-NEXT: ret <2 x i1> [[RET]]
8180
;
8281
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ult-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X:%.*]], <i8 3, i8 0>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

0 commit comments

Comments
 (0)