Commit d8b5db1

[InstCombine] Fold (X * (2^N + 1)) >> N -> X + (X >> N), or directly to X if X >> N is 0
Alive2 Proofs: https://alive2.llvm.org/ce/z/eSinJY https://alive2.llvm.org/ce/z/sweDgc https://alive2.llvm.org/ce/z/-2dXZi
Parent: 7172e34
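To make the new fold concrete, here is a minimal before/after sketch in LLVM IR (hand-written for illustration, not taken from the tests in this commit; the function and value names are invented). With N = 1 the multiplier is 3, and because the multiply is nuw, the shifted product can be rewritten as an add:

define i32 @example_before(i32 %x) {
  ; (x * 3) >> 1, where the multiply is known not to wrap
  %m = mul nuw i32 %x, 3
  %s = lshr i32 %m, 1
  ret i32 %s
}

define i32 @example_after(i32 %x) {
  ; the same value computed without the multiply: x + (x >> 1)
  %t = lshr i32 %x, 1
  %r = add nuw i32 %x, %t
  ret i32 %r
}

When the top N bits of X are known to be zero, the X >> N term is zero and the whole expression collapses to X; that case is handled in InstructionSimplify below, while InstCombine performs the add rewrite.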

4 files changed, +142 -52 lines


llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 41 additions & 1 deletion
@@ -1479,6 +1479,30 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  // Look for a "splat" mul pattern - it replicates bits across each half
+  // of a value, so a right shift is just a mask of the low bits:
+  // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt))) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC) {
+      // FIXME: This condition should be covered by computeKnownBits, but
+      // for some reason it is not, so keep this in for now. This has no
+      // negative effects, but KnownBits should be able to infer a number of
+      // leading bits based on 2^N + 1 not wrapping, as that means 2^N must not
+      // wrap either, which means the top N bits of X must be 0.
+      if (ShAmtC * 2 == BitWidth)
+        return X;
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
+      if (XKnown.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // ((X << A) | Y) >> A -> X if effective width of Y is not larger than A.
   // We can return X as we do in the above case since OR alters no bits in X.
   // SimplifyDemandedBits in InstCombine can do more general optimization for
@@ -1518,8 +1542,24 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
       match(Op0, m_Shl(m_AllOnes(), m_Specific(Op1))))
     return Constant::getAllOnesValue(Op0->getType());
 
-  // (X << A) >> A -> X
+  const APInt *MulC;
+  const APInt *ShAmt;
   Value *X;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt)) &&
+      cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC &&
+        (ShAmtC < BitWidth - 1)) /* Minus 1 for the sign bit */ {
+      KnownBits KnownX = computeKnownBits(X, /* Depth */ 0, Q);
+      if (KnownX.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
+  // (X << A) >> A -> X
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
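As an aside, a sketch of the simplify-to-X case that the code above targets (illustrative IR, not one of the committed tests): masking X to its low 16 bits makes its top 16 bits known zero, so shifting the splat product right by 16 gives back X itself.

define i32 @splat_mul_then_shift(i32 %x) {
  %low = and i32 %x, 65535      ; top 16 bits of %low are known zero
  %m = mul nuw i32 %low, 65537  ; 65537 = 2^16 + 1, the "splat" multiplier
  %t = lshr i32 %m, 16          ; ((%low << 16) + %low) >> 16 == %low
  ret i32 %t                    ; simplifies to ret i32 %low
}

The half-width special case (ShAmtC * 2 == BitWidth) returns X without consulting KnownBits, because a non-wrapping multiply by 2^N + 1 at width 2N already forces the top N bits of X to be zero, as the FIXME above explains.
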
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 47 additions & 20 deletions
@@ -1456,30 +1456,42 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     }
 
     const APInt *MulC;
-    if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
-      // Look for a "splat" mul pattern - it replicates bits across each half of
-      // a value, so a right shift is just a mask of the low bits:
-      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
-      // TODO: Generalize to allow more than just half-width shifts?
-      if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
-          MulC->logBase2() == ShAmtC)
-        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+    if (match(Op0, m_OneUse(m_NUWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+
+        // lshr (mul nuw (X, 2^N + 1)), N -> add nuw (X, lshr(X, N))
+        auto *NewAdd = BinaryOperator::CreateNUWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
+        NewAdd->setHasNoSignedWrap(
+            cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
+        return NewAdd;
+      }
 
       // The one-use check is not strictly necessary, but codegen may not be
       // able to invert the transform and perf may suffer with an extra mul
       // instruction.
-      if (Op0->hasOneUse()) {
-        APInt NewMulC = MulC->lshr(ShAmtC);
-        // if c is divisible by (1 << ShAmtC):
-        // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
-        if (MulC->eq(NewMulC.shl(ShAmtC))) {
-          auto *NewMul =
-              BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
-          assert(ShAmtC != 0 &&
-                 "lshr X, 0 should be handled by simplifyLShrInst.");
-          NewMul->setHasNoSignedWrap(true);
-          return NewMul;
-        }
+      APInt NewMulC = MulC->lshr(ShAmtC);
+      // if c is divisible by (1 << ShAmtC):
+      // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
+      if (MulC->eq(NewMulC.shl(ShAmtC))) {
+        auto *NewMul =
+            BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
+        assert(ShAmtC != 0 &&
+               "lshr X, 0 should be handled by simplifyLShrInst.");
+        NewMul->setHasNoSignedWrap(true);
+        return NewMul;
+      }
+    }
+
+    // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N))
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+        return BinaryOperator::CreateNSWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
       }
     }
 
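For reference, a sketch of what the rewritten visitLShr path produces (illustrative IR; compare the updated @lshr_mul_times_5_div_4 checks in ashr-lshr.ll below): a one-use mul nuw by 2^N + 1 followed by lshr N becomes a shift plus an add, with the wrap flags carried over from the multiply.

define i32 @times5_div4_before(i32 %x) {
  %mul = mul nuw nsw i32 %x, 5   ; 5 = 2^2 + 1
  %res = lshr i32 %mul, 2
  ret i32 %res
}

define i32 @times5_div4_after(i32 %x) {
  %shr = lshr i32 %x, 2          ; x >> 2
  %res = add nuw nsw i32 %x, %shr
  ret i32 %res
}
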
@@ -1686,6 +1698,21 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
       if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
         return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
     }
+
+    const APInt *MulC;
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC)))) &&
+        (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+         MulC->logBase2() == ShAmt &&
+         (ShAmt < BitWidth - 1))) /* Minus 1 for the sign bit */ {
+
+      // ashr (mul nsw (X, 2^N + 1)), N -> add nsw (X, ashr(X, N))
+      auto *NewAdd = BinaryOperator::CreateNSWAdd(
+          X,
+          Builder.CreateAShr(X, ConstantInt::get(Ty, ShAmt), "", I.isExact()));
+      NewAdd->setHasNoUnsignedWrap(
+          cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
+      return NewAdd;
+    }
   }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);

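And the signed counterpart added to visitAShr, again as an illustrative sketch rather than a committed test: a one-use mul nsw by 2^N + 1 followed by ashr N becomes an arithmetic shift plus an nsw add, with nuw propagated from the multiply when it is present.

define i32 @ashr_times5_div4_before(i32 %x) {
  %mul = mul nsw i32 %x, 5       ; 5 = 2^2 + 1
  %res = ashr i32 %mul, 2
  ret i32 %res
}

define i32 @ashr_times5_div4_after(i32 %x) {
  %shr = ashr i32 %x, 2          ; x >> 2, arithmetic
  %res = add nsw i32 %x, %shr
  ret i32 %res
}
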
llvm/test/Transforms/InstCombine/ashr-lshr.ll

Lines changed: 49 additions & 24 deletions
@@ -607,8 +607,8 @@ define <2 x i8> @ashr_known_pos_exact_vec(<2 x i8> %x, <2 x i8> %y) {
 
 define i32 @lshr_mul_times_3_div_2(i32 %0) {
 ; CHECK-LABEL: @lshr_mul_times_3_div_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 3
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nsw nuw i32 %0, 3
@@ -618,8 +618,8 @@ define i32 @lshr_mul_times_3_div_2(i32 %0) {
 
 define i32 @lshr_mul_times_3_div_2_exact(i32 %x) {
 ; CHECK-LABEL: @lshr_mul_times_3_div_2_exact(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nsw i32 %x, 3
@@ -657,8 +657,8 @@ define i32 @mul_times_3_div_2_multiuse_lshr(i32 %x) {
 
 define i32 @lshr_mul_times_3_div_2_exact_2(i32 %x) {
 ; CHECK-LABEL: @lshr_mul_times_3_div_2_exact_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nuw i32 %x, 3
@@ -668,8 +668,8 @@ define i32 @lshr_mul_times_3_div_2_exact_2(i32 %x) {
 
 define i32 @lshr_mul_times_5_div_4(i32 %0) {
 ; CHECK-LABEL: @lshr_mul_times_5_div_4(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 5
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nsw nuw i32 %0, 5
@@ -679,8 +679,8 @@ define i32 @lshr_mul_times_5_div_4(i32 %0) {
 
 define i32 @lshr_mul_times_5_div_4_exact(i32 %x) {
 ; CHECK-LABEL: @lshr_mul_times_5_div_4_exact(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 5
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact i32 [[MUL]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nsw i32 %x, 5
@@ -718,8 +718,8 @@ define i32 @mul_times_5_div_4_multiuse_lshr(i32 %x) {
 
 define i32 @lshr_mul_times_5_div_4_exact_2(i32 %x) {
 ; CHECK-LABEL: @lshr_mul_times_5_div_4_exact_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 5
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact i32 [[MUL]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nuw i32 %x, 5
@@ -729,8 +729,8 @@ define i32 @lshr_mul_times_5_div_4_exact_2(i32 %x) {
 
 define i32 @ashr_mul_times_3_div_2(i32 %0) {
 ; CHECK-LABEL: @ashr_mul_times_3_div_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 3
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nuw nsw i32 %0, 3
@@ -740,8 +740,8 @@ define i32 @ashr_mul_times_3_div_2(i32 %0) {
 
 define i32 @ashr_mul_times_3_div_2_exact(i32 %x) {
 ; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nsw i32 %x, 3
@@ -792,8 +792,8 @@ define i32 @mul_times_3_div_2_multiuse_ashr(i32 %x) {
 
 define i32 @ashr_mul_times_3_div_2_exact_2(i32 %x) {
 ; CHECK-LABEL: @ashr_mul_times_3_div_2_exact_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nsw i32 %x, 3
@@ -803,8 +803,8 @@ define i32 @ashr_mul_times_3_div_2_exact_2(i32 %x) {
 
 define i32 @ashr_mul_times_5_div_4(i32 %0) {
 ; CHECK-LABEL: @ashr_mul_times_5_div_4(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 5
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nuw nsw i32 %0, 5
@@ -814,8 +814,8 @@ define i32 @ashr_mul_times_5_div_4(i32 %0) {
 
 define i32 @ashr_mul_times_5_div_4_exact(i32 %x) {
 ; CHECK-LABEL: @ashr_mul_times_5_div_4_exact(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 5
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i32 [[MUL]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nsw i32 %x, 5
@@ -853,13 +853,38 @@ define i32 @mul_times_5_div_4_multiuse_ashr(i32 %x) {
 
 define i32 @ashr_mul_times_5_div_4_exact_2(i32 %x) {
 ; CHECK-LABEL: @ashr_mul_times_5_div_4_exact_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 5
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i32 [[MUL]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nsw i32 %x, 5
   %ashr = ashr exact i32 %mul, 2
   ret i32 %ashr
 }
 
+define i32 @mul_splat_fold_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_active_bits(
+; CHECK-NEXT:    [[M:%.*]] = mul nuw i32 [[X:%.*]], 65537
+; CHECK-NEXT:    [[T:%.*]] = ashr i32 [[M]], 16
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %x, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
+; Negative test
+
+define i32 @mul_splat_fold_no_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_no_known_active_bits(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %m = mul nsw i32 %x, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
 declare void @use(i32)

llvm/test/Transforms/InstCombine/lshr.ll

Lines changed: 5 additions & 7 deletions
@@ -348,8 +348,7 @@ define <2 x i32> @narrow_lshr_constant(<2 x i8> %x, <2 x i8> %y) {
 
 define i32 @mul_splat_fold(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold(
-; CHECK-NEXT:    [[T:%.*]] = and i32 [[X:%.*]], 65535
-; CHECK-NEXT:    ret i32 [[T]]
+; CHECK-NEXT:    ret i32 [[X:%.*]]
 ;
   %m = mul nuw i32 %x, 65537
   %t = lshr i32 %m, 16
@@ -362,8 +361,7 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
 ; CHECK-LABEL: @mul_splat_fold_vec(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw <3 x i14> [[X:%.*]], <i14 129, i14 129, i14 129>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[M]])
-; CHECK-NEXT:    [[T:%.*]] = and <3 x i14> [[X]], <i14 127, i14 127, i14 127>
-; CHECK-NEXT:    ret <3 x i14> [[T]]
+; CHECK-NEXT:    ret <3 x i14> [[X]]
 ;
   %m = mul nuw <3 x i14> %x, <i14 129, i14 129, i14 129>
   call void @usevec(<3 x i14> %m)
@@ -632,16 +630,16 @@ define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
 
 define i32 @mul_splat_fold_no_nuw(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_nuw(
-; CHECK-NEXT:    [[M:%.*]] = mul nsw i32 [[X:%.*]], 65537
-; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[M]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[T]]
 ;
   %m = mul nsw i32 %x, 65537
   %t = lshr i32 %m, 16
   ret i32 %t
 }
 
-; Negative test
+; Negative test
 
 define i32 @mul_splat_fold_no_flags(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_flags(
