Skip to content

Commit d77eb9e

Browse files
committed
[InstCombine] Improve mask detection in foldICmpWithLowBitMaskedVal
Make recursive matcher that is able to detect a lot more patterns. Proofs for all supported patterns: https://alive2.llvm.org/ce/z/fSQ3nZ Differential Revision: https://reviews.llvm.org/D159058
1 parent f89e4e3 commit d77eb9e

14 files changed

+184
-68
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,19 @@ inline api_pred_ty<is_negated_power2> m_NegatedPower2(const APInt *&V) {
564564
return V;
565565
}
566566

567+
struct is_negated_power2_or_zero {
568+
bool isValue(const APInt &C) { return !C || C.isNegatedPowerOf2(); }
569+
};
570+
/// Match a integer or vector negated power-of-2.
571+
/// For vectors, this includes constants with undefined elements.
572+
inline cst_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
573+
return cst_pred_ty<is_negated_power2_or_zero>();
574+
}
575+
inline api_pred_ty<is_negated_power2_or_zero>
576+
m_NegatedPower2OrZero(const APInt *&V) {
577+
return V;
578+
}
579+
567580
struct is_power2_or_zero {
568581
bool isValue(const APInt &C) { return !C || C.isPowerOf2(); }
569582
};
@@ -595,6 +608,18 @@ inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
595608
}
596609
inline api_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) { return V; }
597610

611+
struct is_lowbit_mask_or_zero {
612+
bool isValue(const APInt &C) { return !C || C.isMask(); }
613+
};
614+
/// Match an integer or vector with only the low bit(s) set.
615+
/// For vectors, this includes constants with undefined elements.
616+
inline cst_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
617+
return cst_pred_ty<is_lowbit_mask_or_zero>();
618+
}
619+
inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
620+
return V;
621+
}
622+
598623
struct icmp_pred_with_threshold {
599624
ICmpInst::Predicate Pred;
600625
const APInt *Thr;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 109 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4069,6 +4069,95 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
40694069
return nullptr;
40704070
}
40714071

4072+
// Returns whether V is a Mask ((X + 1) & X == 0) or ~Mask (-Pow2OrZero)
4073+
static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
4074+
unsigned Depth = 0) {
4075+
if (Not ? match(V, m_NegatedPower2OrZero()) : match(V, m_LowBitMaskOrZero()))
4076+
return true;
4077+
if (V->getType()->getScalarSizeInBits() == 1)
4078+
return true;
4079+
if (Depth++ >= MaxAnalysisRecursionDepth)
4080+
return false;
4081+
Value *X;
4082+
const Instruction *I = dyn_cast<Instruction>(V);
4083+
if (!I)
4084+
return false;
4085+
switch (I->getOpcode()) {
4086+
case Instruction::ZExt:
4087+
// ZExt(Mask) is a Mask.
4088+
return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4089+
case Instruction::SExt:
4090+
// SExt(Mask) is a Mask.
4091+
// SExt(~Mask) is a ~Mask.
4092+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4093+
case Instruction::And:
4094+
case Instruction::Or:
4095+
// Mask0 | Mask1 is a Mask.
4096+
// Mask0 & Mask1 is a Mask.
4097+
// ~Mask0 | ~Mask1 is a ~Mask.
4098+
// ~Mask0 & ~Mask1 is a ~Mask.
4099+
return isMaskOrZero(I->getOperand(1), Not, Q, Depth) &&
4100+
isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4101+
case Instruction::Xor:
4102+
if (match(V, m_Not(m_Value(X))))
4103+
return isMaskOrZero(X, !Not, Q, Depth);
4104+
4105+
// (X ^ (X - 1)) is a Mask
4106+
return !Not &&
4107+
match(V, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())));
4108+
case Instruction::Select:
4109+
// c ? Mask0 : Mask1 is a Mask.
4110+
return isMaskOrZero(I->getOperand(1), Not, Q, Depth) &&
4111+
isMaskOrZero(I->getOperand(2), Not, Q, Depth);
4112+
case Instruction::Shl:
4113+
// (~Mask) << X is a ~Mask.
4114+
return Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4115+
case Instruction::LShr:
4116+
// Mask >> X is a Mask.
4117+
return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4118+
case Instruction::AShr:
4119+
// Mask s>> X is a Mask.
4120+
// ~Mask s>> X is a ~Mask.
4121+
return isMaskOrZero(I->getOperand(0), Not, Q, Depth);
4122+
case Instruction::Add:
4123+
// Pow2 - 1 is a Mask.
4124+
if (!Not && match(I->getOperand(1), m_AllOnes()))
4125+
return isKnownToBeAPowerOfTwo(I->getOperand(0), Q.DL, /*OrZero*/ true,
4126+
Depth, Q.AC, Q.CxtI, Q.DT);
4127+
break;
4128+
case Instruction::Sub:
4129+
// -Pow2 is a ~Mask.
4130+
if (Not && match(I->getOperand(0), m_Zero()))
4131+
return isKnownToBeAPowerOfTwo(I->getOperand(1), Q.DL, /*OrZero*/ true,
4132+
Depth, Q.AC, Q.CxtI, Q.DT);
4133+
break;
4134+
case Instruction::Call: {
4135+
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
4136+
switch (II->getIntrinsicID()) {
4137+
// min/max(Mask0, Mask1) is a Mask.
4138+
// min/max(~Mask0, ~Mask1) is a ~Mask.
4139+
case Intrinsic::umax:
4140+
case Intrinsic::smax:
4141+
case Intrinsic::umin:
4142+
case Intrinsic::smin:
4143+
return isMaskOrZero(II->getArgOperand(1), Not, Q, Depth) &&
4144+
isMaskOrZero(II->getArgOperand(0), Not, Q, Depth);
4145+
4146+
// In the context of masks, bitreverse(Mask) == ~Mask
4147+
case Intrinsic::bitreverse:
4148+
return isMaskOrZero(II->getArgOperand(0), !Not, Q, Depth);
4149+
default:
4150+
break;
4151+
}
4152+
}
4153+
break;
4154+
}
4155+
default:
4156+
break;
4157+
}
4158+
return false;
4159+
}
4160+
40724161
/// Some comparisons can be simplified.
40734162
/// In this case, we are looking for comparisons that look like
40744163
/// a check for a lossy truncation.
@@ -4083,17 +4172,21 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
40834172
/// For some predicates, the operands are commutative.
40844173
/// For others, x can only be on a specific side.
40854174
static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
4086-
Value *Op1,
4087-
InstCombiner::BuilderTy &Builder) {
4088-
Value *M, *Y;
4089-
auto m_VariableMask = m_CombineOr(
4090-
m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())),
4091-
m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())),
4092-
m_CombineOr(m_LShr(m_AllOnes(), m_Value()),
4093-
m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y))));
4094-
auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask());
4095-
4096-
if (!match(Op0, m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Specific(Op1))))
4175+
Value *Op1, const SimplifyQuery &Q,
4176+
InstCombiner &IC) {
4177+
Value *M;
4178+
bool NeedsNot = false;
4179+
4180+
auto CheckMask = [&](Value *V, bool Not) {
4181+
if (ICmpInst::isSigned(Pred) && !match(V, m_ImmConstant()))
4182+
return false;
4183+
return isMaskOrZero(V, Not, Q);
4184+
};
4185+
4186+
if (!match(Op0, m_c_And(m_Specific(Op1), m_Value(M))))
4187+
return nullptr;
4188+
4189+
if (!CheckMask(M, /*Not*/ false))
40974190
return nullptr;
40984191

40994192
ICmpInst::Predicate DstPred;
@@ -4163,7 +4256,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
41634256
M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
41644257
}
41654258

4166-
return Builder.CreateICmp(DstPred, Op1, M);
4259+
if (NeedsNot)
4260+
M = IC.Builder.CreateNot(M);
4261+
return IC.Builder.CreateICmp(DstPred, Op1, M);
41674262
}
41684263

41694264
/// Some comparisons can be simplified.
@@ -6980,7 +7075,8 @@ Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
69807075
}
69817076
}
69827077

6983-
if (Value *V = foldICmpWithLowBitMaskedVal(Pred, Op0, Op1, Builder))
7078+
const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
7079+
if (Value *V = foldICmpWithLowBitMaskedVal(Pred, Op0, Op1, Q, *this))
69847080
return replaceInstUsesWith(CxtI, V);
69857081

69867082
return nullptr;

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-eq-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 4, i8 1>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ne-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X:%.*]], <i8 3, i8 0>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sge-to-icmp-sle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
5050

5151
define <2 x i1> @p2_vec_nonsplat_edgecase(<2 x i8> %x) {
5252
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
53-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X:%.*]], <i8 3, i8 0>
54-
; CHECK-NEXT: [[RET:%.*]] = icmp sge <2 x i8> [[TMP0]], [[X]]
53+
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 4, i8 1>
5554
; CHECK-NEXT: ret <2 x i1> [[RET]]
5655
;
5756
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sgt-to-icmp-sgt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ define <2 x i1> @p2_vec_nonsplat() {
6363
define <2 x i1> @p2_vec_nonsplat_edgecase() {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
6565
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
66-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X]], <i8 3, i8 0>
67-
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X]], [[TMP0]]
66+
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X]], <i8 3, i8 0>
6867
; CHECK-NEXT: ret <2 x i1> [[RET]]
6968
;
7069
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-sle-to-icmp-sle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ define <2 x i1> @p2_vec_nonsplat() {
6363
define <2 x i1> @p2_vec_nonsplat_edgecase() {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
6565
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
66-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X]], <i8 3, i8 0>
67-
; CHECK-NEXT: [[RET:%.*]] = icmp sle <2 x i8> [[X]], [[TMP0]]
66+
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[X]], <i8 4, i8 1>
6867
; CHECK-NEXT: ret <2 x i1> [[RET]]
6968
;
7069
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-slt-to-icmp-sgt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
5050

5151
define <2 x i1> @p2_vec_nonsplat_edgecase(<2 x i8> %x) {
5252
; CHECK-LABEL: @p2_vec_nonsplat_edgecase(
53-
; CHECK-NEXT: [[TMP0:%.*]] = and <2 x i8> [[X:%.*]], <i8 3, i8 0>
54-
; CHECK-NEXT: [[RET:%.*]] = icmp slt <2 x i8> [[TMP0]], [[X]]
53+
; CHECK-NEXT: [[RET:%.*]] = icmp sgt <2 x i8> [[X:%.*]], <i8 3, i8 0>
5554
; CHECK-NEXT: ret <2 x i1> [[RET]]
5655
;
5756
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-uge-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 4, i8 1>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ugt-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ define <2 x i1> @p2_vec_nonsplat() {
7575
define <2 x i1> @p2_vec_nonsplat_edgecase0() {
7676
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
7777
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
78-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X]], <i8 -4, i8 -1>
79-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
78+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X]], <i8 3, i8 0>
8079
; CHECK-NEXT: ret <2 x i1> [[RET]]
8180
;
8281
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ule-to-icmp-ule.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ define <2 x i1> @p2_vec_nonsplat() {
7575
define <2 x i1> @p2_vec_nonsplat_edgecase0() {
7676
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
7777
; CHECK-NEXT: [[X:%.*]] = call <2 x i8> @gen2x8()
78-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X]], <i8 -4, i8 -1>
79-
; CHECK-NEXT: [[RET:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
78+
; CHECK-NEXT: [[RET:%.*]] = icmp ult <2 x i8> [[X]], <i8 4, i8 1>
8079
; CHECK-NEXT: ret <2 x i1> [[RET]]
8180
;
8281
%x = call <2 x i8> @gen2x8()

llvm/test/Transforms/InstCombine/canonicalize-constant-low-bit-mask-and-icmp-ult-to-icmp-ugt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ define <2 x i1> @p2_vec_nonsplat(<2 x i8> %x) {
6262

6363
define <2 x i1> @p2_vec_nonsplat_edgecase0(<2 x i8> %x) {
6464
; CHECK-LABEL: @p2_vec_nonsplat_edgecase0(
65-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 -4, i8 -1>
66-
; CHECK-NEXT: [[RET:%.*]] = icmp ne <2 x i8> [[TMP1]], zeroinitializer
65+
; CHECK-NEXT: [[RET:%.*]] = icmp ugt <2 x i8> [[X:%.*]], <i8 3, i8 0>
6766
; CHECK-NEXT: ret <2 x i1> [[RET]]
6867
;
6968
%tmp0 = and <2 x i8> %x, <i8 3, i8 0>

0 commit comments

Comments
 (0)