Skip to content

Commit 631fd95

Browse files
committed
[InstCombine] Fold (op x, ({z,s}ext (icmp eq x, C))) to select
`(op x, ({z,s}ext (icmp eq x, C)))` is either `(op C, ({z,s}ext 1))` or `(op x, 0)`. If both possibilities simplify (i.e constant fold for the former and either constant fold or converted to just `x` in the latter), fold to: `(select (icmp eq x, C), (op C, ({z,s}ext 1)), (op x, 0)`. Which is easier to analyze and should get roughly the same or better codegen (in most cases something like `({z,s}ext (icmp))` lowers simliarly to a `select`).
1 parent 86ebad9 commit 631fd95

File tree

9 files changed

+123
-25
lines changed

9 files changed

+123
-25
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
14871487
if (Instruction *Phi = foldBinopWithPhiOperands(I))
14881488
return Phi;
14891489

1490+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
1491+
return replaceInstUsesWith(I, R);
1492+
14901493
// (A*B)+(A*C) -> A*(B+C) etc
14911494
if (Value *V = foldUsingDistributiveLaws(I))
14921495
return replaceInstUsesWith(I, V);
@@ -2092,6 +2095,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
20922095
if (Instruction *Phi = foldBinopWithPhiOperands(I))
20932096
return Phi;
20942097

2098+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
2099+
return replaceInstUsesWith(I, R);
2100+
20952101
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
20962102

20972103
// If this is a 'B = x-(-A)', change to B = x+A.

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,6 +2275,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
22752275
if (Instruction *Phi = foldBinopWithPhiOperands(I))
22762276
return Phi;
22772277

2278+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
2279+
return replaceInstUsesWith(I, R);
2280+
22782281
// See if we can simplify any instructions used by the instruction whose sole
22792282
// purpose is to compute bits we don't care about.
22802283
if (SimplifyDemandedInstructionBits(I))
@@ -3438,6 +3441,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
34383441
if (Instruction *Phi = foldBinopWithPhiOperands(I))
34393442
return Phi;
34403443

3444+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
3445+
return replaceInstUsesWith(I, R);
3446+
34413447
// See if we can simplify any instructions used by the instruction whose sole
34423448
// purpose is to compute bits we don't care about.
34433449
if (SimplifyDemandedInstructionBits(I))
@@ -4571,6 +4577,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
45714577
if (Instruction *NewXor = foldXorToXor(I, Builder))
45724578
return NewXor;
45734579

4580+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
4581+
return replaceInstUsesWith(I, R);
4582+
45744583
// (A&B)^(A&C) -> A&(B^C) etc
45754584
if (Value *V = foldUsingDistributiveLaws(I))
45764585
return replaceInstUsesWith(I, V);

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
14911491
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI);
14921492
if (!II) return visitCallBase(CI);
14931493

1494+
if (Value *R = foldOpOfXWithXEqC(II, SQ.getWithInstruction(&CI)))
1495+
return replaceInstUsesWith(CI, R);
1496+
14941497
// For atomic unordered mem intrinsics if len is not a positive or
14951498
// not a multiple of element size then behavior is undefined.
14961499
if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II))

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
755755

756756
Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned);
757757

758+
Value *foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ);
758759
bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
759760
void tryToSinkInstructionDbgValues(
760761
Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
204204
if (Instruction *Phi = foldBinopWithPhiOperands(I))
205205
return Phi;
206206

207+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
208+
return replaceInstUsesWith(I, R);
209+
207210
if (Value *V = foldUsingDistributiveLaws(I))
208211
return replaceInstUsesWith(I, V);
209212

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
10201020
if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
10211021
return V;
10221022

1023+
if (Value *R = foldOpOfXWithXEqC(&I, Q))
1024+
return replaceInstUsesWith(I, R);
1025+
10231026
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
10241027
Type *Ty = I.getType();
10251028
unsigned BitWidth = Ty->getScalarSizeInBits();
@@ -1252,6 +1255,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
12521255
if (Instruction *R = commonShiftTransforms(I))
12531256
return R;
12541257

1258+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
1259+
return replaceInstUsesWith(I, R);
1260+
12551261
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
12561262
Type *Ty = I.getType();
12571263
Value *X;
@@ -1625,6 +1631,9 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
16251631
if (Instruction *R = commonShiftTransforms(I))
16261632
return R;
16271633

1634+
if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
1635+
return replaceInstUsesWith(I, R);
1636+
16281637
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
16291638
Type *Ty = I.getType();
16301639
unsigned BitWidth = Ty->getScalarSizeInBits();

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4829,6 +4829,89 @@ void InstCombinerImpl::tryToSinkInstructionDbgValues(
48294829
}
48304830
}
48314831

4832+
// If we have:
4833+
// `(op X, (zext/sext (icmp eq X, C)))`
4834+
// We can transform it to:
4835+
// `(select (icmp eq X, C), (op C, (zext/sext 1)), (op X, 0))`
4836+
// We do so if the `zext/sext` is one use and `(op X, 0)` simplifies.
4837+
Value *InstCombinerImpl::foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ) {
4838+
Value *Cond;
4839+
Constant *C, *ExtC;
4840+
4841+
// match `(op X, (zext/sext (icmp eq X, C)))` and see if `(op X, 0)`
4842+
// simplifies.
4843+
// If we match and simplify, store the `icmp` in `Cond`, `(zext/sext C)` in
4844+
// `ExtC`.
4845+
auto MatchXWithXEqC = [&](Value *Op0, Value *Op1) -> Value * {
4846+
if (match(Op0, m_OneUse(m_ZExtOrSExt(m_Value(Cond))))) {
4847+
ICmpInst::Predicate Pred;
4848+
if (!match(Cond, m_ICmp(Pred, m_Specific(Op1), m_ImmConstant(C))) ||
4849+
Pred != ICmpInst::ICMP_EQ)
4850+
return nullptr;
4851+
4852+
ExtC = isa<SExtInst>(Op0) ? ConstantInt::getAllOnesValue(C->getType())
4853+
: ConstantInt::get(C->getType(), 1);
4854+
return simplifyWithOpReplaced(Op, Op0,
4855+
Constant::getNullValue(Op1->getType()), SQ,
4856+
/*AllowRefinement=*/true);
4857+
}
4858+
return nullptr;
4859+
};
4860+
4861+
Value *SimpleOp = nullptr, *ConstOp = nullptr;
4862+
if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
4863+
switch (BO->getOpcode()) {
4864+
// Potential TODO: For all of these, if Op1 is the compare, the compare
4865+
// must be true and we could replace Op0 with C (otherwise immediate UB).
4866+
case Instruction::UDiv:
4867+
case Instruction::SDiv:
4868+
case Instruction::URem:
4869+
case Instruction::SRem:
4870+
return nullptr;
4871+
default:
4872+
break;
4873+
}
4874+
4875+
// Try X is Op0
4876+
if ((SimpleOp = MatchXWithXEqC(BO->getOperand(0), BO->getOperand(1))))
4877+
ConstOp = Builder.CreateBinOp(BO->getOpcode(), ExtC, C);
4878+
// Try X is Op1
4879+
else if ((SimpleOp = MatchXWithXEqC(BO->getOperand(1), BO->getOperand(0))))
4880+
ConstOp = Builder.CreateBinOp(BO->getOpcode(), C, ExtC);
4881+
} else if (auto *II = dyn_cast<IntrinsicInst>(Op)) {
4882+
switch (II->getIntrinsicID()) {
4883+
default:
4884+
return nullptr;
4885+
case Intrinsic::sshl_sat:
4886+
case Intrinsic::ushl_sat:
4887+
case Intrinsic::umax:
4888+
case Intrinsic::umin:
4889+
case Intrinsic::smax:
4890+
case Intrinsic::smin:
4891+
case Intrinsic::uadd_sat:
4892+
case Intrinsic::usub_sat:
4893+
case Intrinsic::sadd_sat:
4894+
case Intrinsic::ssub_sat:
4895+
// Try X is Op0
4896+
if ((SimpleOp =
4897+
MatchXWithXEqC(II->getArgOperand(0), II->getArgOperand(1))))
4898+
ConstOp = Builder.CreateBinaryIntrinsic(II->getIntrinsicID(), ExtC, C);
4899+
// Try X is Op1
4900+
else if ((SimpleOp =
4901+
MatchXWithXEqC(II->getArgOperand(1), II->getArgOperand(0))))
4902+
ConstOp = Builder.CreateBinaryIntrinsic(II->getIntrinsicID(), C, ExtC);
4903+
break;
4904+
}
4905+
}
4906+
4907+
assert((SimpleOp == nullptr) == (ConstOp == nullptr) &&
4908+
"Simplfied Op and Constant Op are de-synced!");
4909+
if (SimpleOp == nullptr)
4910+
return nullptr;
4911+
4912+
return Builder.CreateSelect(Cond, ConstOp, SimpleOp);
4913+
}
4914+
48324915
void InstCombinerImpl::tryToSinkInstructionDbgVariableRecords(
48334916
Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
48344917
BasicBlock *DestBlock,

llvm/test/Transforms/InstCombine/apint-shift.ll

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -564,14 +564,7 @@ define i40 @test26(i40 %A) {
564564
; https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=9880
565565
define i177 @ossfuzz_9880(i177 %X) {
566566
; CHECK-LABEL: @ossfuzz_9880(
567-
; CHECK-NEXT: [[A:%.*]] = alloca i177, align 8
568-
; CHECK-NEXT: [[L1:%.*]] = load i177, ptr [[A]], align 4
569-
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i177 [[L1]], -1
570-
; CHECK-NEXT: [[B5_NEG:%.*]] = sext i1 [[TMP1]] to i177
571-
; CHECK-NEXT: [[B14:%.*]] = add i177 [[L1]], [[B5_NEG]]
572-
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i177 [[B14]], -1
573-
; CHECK-NEXT: [[B1:%.*]] = zext i1 [[TMP2]] to i177
574-
; CHECK-NEXT: ret i177 [[B1]]
567+
; CHECK-NEXT: ret i177 0
575568
;
576569
%A = alloca i177
577570
%L1 = load i177, ptr %A

llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
declare void @use.i8(i8)
55
define i8 @fold_add_zext_eq_0(i8 %x) {
66
; CHECK-LABEL: @fold_add_zext_eq_0(
7-
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 0
8-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
9-
; CHECK-NEXT: [[R:%.*]] = add i8 [[X_EQ_EXT]], [[X]]
7+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 1)
108
; CHECK-NEXT: ret i8 [[R]]
119
;
1210
%x_eq = icmp eq i8 %x, 0
@@ -18,8 +16,7 @@ define i8 @fold_add_zext_eq_0(i8 %x) {
1816
define i8 @fold_add_sext_eq_0(i8 %x) {
1917
; CHECK-LABEL: @fold_add_sext_eq_0(
2018
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 0
21-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
22-
; CHECK-NEXT: [[R:%.*]] = add i8 [[X_EQ_EXT]], [[X]]
19+
; CHECK-NEXT: [[R:%.*]] = select i1 [[X_EQ]], i8 -1, i8 [[X]]
2320
; CHECK-NEXT: ret i8 [[R]]
2421
;
2522
%x_eq = icmp eq i8 %x, 0
@@ -73,8 +70,7 @@ define i8 @fold_mul_sext_eq_12_fail_multiuse(i8 %x) {
7370
define i8 @fold_shl_zext_eq_3_rhs(i8 %x) {
7471
; CHECK-LABEL: @fold_shl_zext_eq_3_rhs(
7572
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
76-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
77-
; CHECK-NEXT: [[R:%.*]] = shl i8 [[X]], [[X_EQ_EXT]]
73+
; CHECK-NEXT: [[R:%.*]] = select i1 [[X_EQ]], i8 6, i8 [[X]]
7874
; CHECK-NEXT: ret i8 [[R]]
7975
;
8076
%x_eq = icmp eq i8 %x, 3
@@ -86,8 +82,7 @@ define i8 @fold_shl_zext_eq_3_rhs(i8 %x) {
8682
define i8 @fold_shl_zext_eq_3_lhs(i8 %x) {
8783
; CHECK-LABEL: @fold_shl_zext_eq_3_lhs(
8884
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
89-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
90-
; CHECK-NEXT: [[R:%.*]] = shl nuw i8 [[X_EQ_EXT]], [[X]]
85+
; CHECK-NEXT: [[R:%.*]] = select i1 [[X_EQ]], i8 8, i8 0
9186
; CHECK-NEXT: ret i8 [[R]]
9287
;
9388
%x_eq = icmp eq i8 %x, 3
@@ -99,8 +94,7 @@ define i8 @fold_shl_zext_eq_3_lhs(i8 %x) {
9994
define <2 x i8> @fold_lshr_sext_eq_15_5_lhs(<2 x i8> %x) {
10095
; CHECK-LABEL: @fold_lshr_sext_eq_15_5_lhs(
10196
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 15, i8 5>
102-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = sext <2 x i1> [[X_EQ]] to <2 x i8>
103-
; CHECK-NEXT: [[R:%.*]] = lshr <2 x i8> [[X_EQ_EXT]], [[X]]
97+
; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[X_EQ]], <2 x i8> <i8 poison, i8 7>, <2 x i8> zeroinitializer
10498
; CHECK-NEXT: ret <2 x i8> [[R]]
10599
;
106100
%x_eq = icmp eq <2 x i8> %x, <i8 15, i8 5>
@@ -122,8 +116,7 @@ define <2 x i8> @fold_lshr_sext_eq_15_poison_rhs(<2 x i8> %x) {
122116
define i8 @fold_umax_zext_eq_9(i8 %x) {
123117
; CHECK-LABEL: @fold_umax_zext_eq_9(
124118
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 9
125-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
126-
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[X_EQ_EXT]])
119+
; CHECK-NEXT: [[R:%.*]] = select i1 [[X_EQ]], i8 -1, i8 [[X]]
127120
; CHECK-NEXT: ret i8 [[R]]
128121
;
129122
%x_eq = icmp eq i8 %x, 9
@@ -161,8 +154,7 @@ define i8 @fold_ushl_sat_zext_eq_3_lhs(i8 %x) {
161154
define i8 @fold_uadd_sat_zext_eq_3_rhs(i8 %x) {
162155
; CHECK-LABEL: @fold_uadd_sat_zext_eq_3_rhs(
163156
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
164-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
165-
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[X]], i8 [[X_EQ_EXT]])
157+
; CHECK-NEXT: [[R:%.*]] = select i1 [[X_EQ]], i8 4, i8 [[X]]
166158
; CHECK-NEXT: ret i8 [[R]]
167159
;
168160
%x_eq = icmp eq i8 %x, 3
@@ -187,8 +179,7 @@ define i8 @fold_ssub_sat_sext_eq_99_lhs_fail(i8 %x) {
187179
define i8 @fold_ssub_sat_zext_eq_99_rhs(i8 %x) {
188180
; CHECK-LABEL: @fold_ssub_sat_zext_eq_99_rhs(
189181
; CHECK-NEXT: [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 99
190-
; CHECK-NEXT: [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
191-
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[X]], i8 [[X_EQ_EXT]])
182+
; CHECK-NEXT: [[R:%.*]] = select i1 [[X_EQ]], i8 98, i8 [[X]]
192183
; CHECK-NEXT: ret i8 [[R]]
193184
;
194185
%x_eq = icmp eq i8 %x, 99

0 commit comments

Comments
 (0)