Skip to content

Commit 8967469

Browse files
committed
[InstCombine] Fold adds + shifts with nsw and nuw flags
I also added mul nsw/nuw 3, div 2 since this was the canonical version of ((x << 1) + x) / 2, which is a specific expression which canonicalization causes the InstCombine to miss it.
1 parent aaf0fe9 commit 8967469

File tree

2 files changed

+62
-15
lines changed

2 files changed

+62
-15
lines changed

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,19 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
12671267
match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)))
12681268
return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
12691269

1270+
// Special Case:
1271+
// if both the add and the shift are nuw, we can omit the AND entirely
1272+
// ((X << Y) nuw + Z nuw) >>u Z --> (X + (Y >>u Z))
1273+
Value *Y;
1274+
if (match(Op0, m_OneUse(m_c_NUWAdd((m_NUWShl(m_Value(X), m_Specific(Op1))),
1275+
m_Value(Y))))) {
1276+
Value *NewLshr = Builder.CreateLShr(Y, Op1, "", I.isExact());
1277+
auto *newAdd = BinaryOperator::CreateNUWAdd(NewLshr, X);
1278+
if (auto *Op0Bin = cast<BinaryOperator>(Op0))
1279+
newAdd->setHasNoSignedWrap(Op0Bin->hasNoSignedWrap());
1280+
return newAdd;
1281+
}
1282+
12701283
if (match(Op1, m_APInt(C))) {
12711284
unsigned ShAmtC = C->getZExtValue();
12721285
auto *II = dyn_cast<IntrinsicInst>(Op0);
@@ -1283,7 +1296,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
12831296
return new ZExtInst(Cmp, Ty);
12841297
}
12851298

1286-
Value *X;
12871299
const APInt *C1;
12881300
if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
12891301
if (C1->ult(ShAmtC)) {
@@ -1328,7 +1340,7 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
13281340
// ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
13291341
// TODO: Consolidate with the more general transform that starts from shl
13301342
// (the shifts are in the opposite order).
1331-
Value *Y;
1343+
13321344
if (match(Op0,
13331345
m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
13341346
m_Value(Y))))) {
@@ -1450,9 +1462,25 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
14501462
NewMul->setHasNoSignedWrap(true);
14511463
return NewMul;
14521464
}
1465+
1466+
// Special case:
1467+
// lshr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
1468+
if (ShAmtC == 1 && MulC->getZExtValue() == 3) {
1469+
auto *NewAdd = BinaryOperator::CreateNUWAdd(
1470+
X,
1471+
Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
1472+
NewAdd->setHasNoSignedWrap(true);
1473+
return NewAdd;
1474+
}
14531475
}
14541476
}
14551477

1478+
// // lshr nsw (mul (X, 3), 1) -> add nsw (X, lshr(X, 1)
1479+
if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
1480+
ShAmtC == 1)
1481+
return BinaryOperator::CreateNSWAdd(
1482+
X, Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
1483+
14561484
// Try to narrow bswap.
14571485
// In the case where the shift amount equals the bitwidth difference, the
14581486
// shift is eliminated.
@@ -1656,6 +1684,26 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
16561684
if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
16571685
return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
16581686
}
1687+
1688+
// Special case: ashr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
1689+
if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
1690+
ShAmt == 1) {
1691+
Value *Shift;
1692+
if (auto *Op0Bin = cast<BinaryOperator>(Op0)) {
1693+
if (Op0Bin->hasNoUnsignedWrap())
1694+
// We can use lshr if the mul is nuw and nsw
1695+
Shift =
1696+
Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
1697+
else
1698+
Shift =
1699+
Builder.CreateAShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
1700+
1701+
auto *NewAdd = BinaryOperator::CreateNSWAdd(X, Shift);
1702+
NewAdd->setHasNoUnsignedWrap(Op0Bin->hasNoUnsignedWrap());
1703+
1704+
return NewAdd;
1705+
}
1706+
}
16591707
}
16601708

16611709
const SimplifyQuery Q = SQ.getWithInstruction(&I);

llvm/test/Transforms/InstCombine/lshr.ll

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,9 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
362362

363363
define i32 @mul_times_3_div_2 (i32 %x) {
364364
; CHECK-LABEL: @mul_times_3_div_2(
365-
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
366-
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 1
367-
; CHECK-NEXT: ret i32 [[TMP2]]
365+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1:%.*]], 1
366+
; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP1]]
367+
; CHECK-NEXT: ret i32 [[TMP3]]
368368
;
369369
%2 = mul nsw nuw i32 %x, 3
370370
%3 = lshr i32 %2, 1
@@ -373,21 +373,20 @@ define i32 @mul_times_3_div_2 (i32 %x) {
373373

374374
define i32 @shl_add_lshr (i32 %x, i32 %c, i32 %y) {
375375
; CHECK-LABEL: @shl_add_lshr(
376-
; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
377-
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i32 [[TMP1]], [[Y:%.*]]
378-
; CHECK-NEXT: [[TMP3:%.*]] = lshr exact i32 [[TMP2]], [[C]]
379-
; CHECK-NEXT: ret i32 [[TMP3]]
376+
; CHECK-NEXT: [[TMP3:%.*]] = lshr exact i32 [[TMP2:%.*]], [[C:%.*]]
377+
; CHECK-NEXT: [[TMP4:%.*]] = add nuw nsw i32 [[TMP3]], [[X:%.*]]
378+
; CHECK-NEXT: ret i32 [[TMP4]]
380379
;
381380
%2 = shl nuw i32 %x, %c
382-
%3 = add nsw nuw i32 %2, %y
381+
%3 = add nuw nsw i32 %2, %y
383382
%4 = lshr exact i32 %3, %c
384383
ret i32 %4
385384
}
386385

387386
define i32 @ashr_mul_times_3_div_2 (i32 %0) {
388387
; CHECK-LABEL: @ashr_mul_times_3_div_2(
389-
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 3
390-
; CHECK-NEXT: [[TMP3:%.*]] = ashr i32 [[TMP2]], 1
388+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
389+
; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
391390
; CHECK-NEXT: ret i32 [[TMP3]]
392391
;
393392
%2 = mul nsw nuw i32 %0, 3
@@ -397,9 +396,9 @@ define i32 @mul_times_3_div_2 (i32 %x) {
397396

398397
define i32 @ashr_mul_times_3_div_2_exact (i32 %0) {
399398
; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
400-
; CHECK-NEXT: [[TMP2:%.*]] = mul nsw i32 [[TMP0:%.*]], 3
401-
; CHECK-NEXT: [[TMP3:%.*]] = ashr exact i32 [[TMP2]], 1
402-
; CHECK-NEXT: ret i32 [[TMP3]]
399+
; CHECK-NEXT: [[TMP3:%.*]] = ashr exact i32 [[TMP2:%.*]], 1
400+
; CHECK-NEXT: [[TMP4:%.*]] = add nsw i32 [[TMP3]], [[TMP2]]
401+
; CHECK-NEXT: ret i32 [[TMP4]]
403402
;
404403
%2 = mul nsw i32 %0, 3
405404
%3 = ashr exact i32 %2, 1

0 commit comments

Comments
 (0)