Commit 009f863 (parent: 6d48496)
4 files changed: +102 -33 lines

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 39 additions & 0 deletions
@@ -1479,6 +1479,29 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  // Look for a "splat" mul pattern - it replicates bits across each half
+  // of a value, so a right shift is just a mask of the low bits:
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt))) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC) {
+      // FIXME: This condition should be covered by the computeKnownBits, but
+      // for some reason it is not, so keep this in for now. This has no
+      // negative effects, but KnownBits should be able to infer a number of
+      // leading bits based on 2^N + 1 not wrapping, as that means 2^N must not
+      // wrap either, which means the top N bits of X must be 0.
+      if (ShAmtC * 2 == BitWidth)
+        return X;
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
+      if (XKnown.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // ((X << A) | Y) >> A -> X if effective width of Y is not larger than A.
   // We can return X as we do in the above case since OR alters no bits in X.
   // SimplifyDemandedBits in InstCombine can do more general optimization for
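
In isolation, the new fold says that multiplying by 2^N + 1 without unsigned wrap copies X into both halves of the product, so shifting the high copy back down recovers X whenever X occupies at most N bits. A minimal standalone C++ check of that arithmetic (my own demo, not part of the commit; N = 16, so the multiplier is 65537):

#include <cassert>
#include <cstdint>

int main() {
  // Demo of: lshr (mul nuw X, 2^N + 1), N --> X, given X has at most
  // N active bits. Here N = 16 and the multiplier is 2^16 + 1 = 65537.
  for (uint32_t X = 0; X <= 0xFFFF; ++X) { // countMaxActiveBits(X) <= 16
    uint32_t M = X * 65537u;               // cannot wrap u32, i.e. "mul nuw"
    assert((M >> 16) == X);                // the shift just recovers X
  }
  return 0;
}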
@@ -1523,6 +1546,22 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt)) &&
+      cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC &&
+        ShAmtC < BitWidth - 1) /* Minus 1 for the sign bit */ {
+      KnownBits KnownX = computeKnownBits(X, /* Depth */ 0, Q);
+      if (KnownX.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // Arithmetic shifting an all-sign-bit value is a no-op.
   unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
   if (NumSignBits == Op0->getType()->getScalarSizeInBits())
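
The ashr variant needs the extra guards because an arithmetic shift replicates the sign bit: requiring no signed wrap on the mul and ShAmtC < BitWidth - 1 keeps the product nonnegative, at which point ashr and lshr agree. A rough standalone illustration under those assumptions (again my own example values, N = 16 on i32):

#include <cassert>
#include <cstdint>

int main() {
  // With X in [0, 2^15), X * 65537 fits in a signed i32 ("mul nsw"),
  // the product is nonnegative, and an arithmetic shift acts like the
  // logical shift in the lshr demo above.
  for (int32_t X = 0; X < (1 << 15); ++X) {
    int32_t M = X * 65537;    // no signed wrap: X * 65537 < 2^31
    assert(M >= 0);           // sign bit clear, so ashr == lshr
    assert((M >> 16) == X);   // ashr i32 (mul nsw X, 65537), 16 --> X
  }
  return 0;
}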

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 27 additions & 26 deletions
@@ -1456,41 +1456,42 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
   }
 
   const APInt *MulC;
-  if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
+  if (match(Op0, m_OneUse(m_NUWMul(m_Value(X), m_APInt(MulC))))) {
     if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
         MulC->logBase2() == ShAmtC) {
-      // Look for a "splat" mul pattern - it replicates bits across each half
-      // of a value, so a right shift is just a mask of the low bits:
-      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
-      if (ShAmtC * 2 == BitWidth)
-        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
 
       // lshr (mul nuw (X, 2^N + 1)), N -> add nuw (X, lshr(X, N))
-      if (Op0->hasOneUse()) {
-        auto *NewAdd = BinaryOperator::CreateNUWAdd(
-            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
-                                  I.isExact()));
-        NewAdd->setHasNoSignedWrap(
-            cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
-        return NewAdd;
-      }
+      auto *NewAdd = BinaryOperator::CreateNUWAdd(
+          X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                I.isExact()));
+      NewAdd->setHasNoSignedWrap(
+          cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
+      return NewAdd;
     }
 
     // The one-use check is not strictly necessary, but codegen may not be
     // able to invert the transform and perf may suffer with an extra mul
     // instruction.
-    if (Op0->hasOneUse()) {
-      APInt NewMulC = MulC->lshr(ShAmtC);
-      // if c is divisible by (1 << ShAmtC):
-      // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
-      if (MulC->eq(NewMulC.shl(ShAmtC))) {
-        auto *NewMul =
-            BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
-        assert(ShAmtC != 0 &&
-               "lshr X, 0 should be handled by simplifyLShrInst.");
-        NewMul->setHasNoSignedWrap(true);
-        return NewMul;
-      }
+    APInt NewMulC = MulC->lshr(ShAmtC);
+    // if c is divisible by (1 << ShAmtC):
+    // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
+    if (MulC->eq(NewMulC.shl(ShAmtC))) {
+      auto *NewMul =
+          BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
+      assert(ShAmtC != 0 &&
+             "lshr X, 0 should be handled by simplifyLShrInst.");
+      NewMul->setHasNoSignedWrap(true);
+      return NewMul;
+    }
+  }
+
+  // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N))
+  if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) {
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC) {
+      return BinaryOperator::CreateNSWAdd(
+          X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                I.isExact()));
     }
   }
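
The rewrite that survives here rests on the identity (X * (2^N + 1)) >> N == X + (X >> N), valid whenever the multiply does not wrap unsigned. A small self-contained check of that identity (illustrative constants of my choosing, N = 8 so the multiplier is 257):

#include <cassert>
#include <cstdint>

int main() {
  // lshr (mul nuw X, 2^N + 1), N --> add nuw X, (lshr X, N)
  // Holds whenever X * (2^N + 1) does not wrap; here N = 8, MulC = 257.
  const uint32_t Limit = UINT32_MAX / 257;    // largest X keeping "nuw"
  for (uint32_t X = 0; X <= Limit; X += 97) { // stride keeps the loop short
    uint32_t Lhs = (X * 257u) >> 8;           // original pattern
    uint32_t Rhs = X + (X >> 8);              // rewritten form
    assert(Lhs == Rhs);
  }
  return 0;
}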

llvm/test/Transforms/InstCombine/ashr-lshr.ll

Lines changed: 22 additions & 0 deletions
@@ -862,4 +862,26 @@ define i32 @ashr_mul_times_5_div_4_exact_2(i32 %x) {
   ret i32 %ashr
 }
 
+define i32 @mul_splat_fold_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_active_bits(
+; CHECK-NEXT:    [[XX:%.*]] = and i32 [[X:%.*]], 360
+; CHECK-NEXT:    ret i32 [[XX]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %xx, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
+define i32 @mul_splat_fold_no_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_no_known_active_bits(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %m = mul nsw i32 %x, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
 declare void @use(i32)

llvm/test/Transforms/InstCombine/lshr.ll

Lines changed: 14 additions & 7 deletions
@@ -348,22 +348,31 @@ define <2 x i32> @narrow_lshr_constant(<2 x i8> %x, <2 x i8> %y) {
 
 define i32 @mul_splat_fold(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold(
-; CHECK-NEXT:    [[T:%.*]] = and i32 [[X:%.*]], 65535
-; CHECK-NEXT:    ret i32 [[T]]
+; CHECK-NEXT:    ret i32 [[X:%.*]]
 ;
   %m = mul nuw i32 %x, 65537
   %t = lshr i32 %m, 16
   ret i32 %t
 }
 
+define i32 @mul_splat_fold_known_zeros(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_zeros(
+; CHECK-NEXT:    [[XX:%.*]] = and i32 [[X:%.*]], 360
+; CHECK-NEXT:    ret i32 [[XX]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %xx, 65537
+  %t = lshr i32 %m, 16
+  ret i32 %t
+}
+
 ; Vector type, extra use, weird types are all ok.
 
 define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
 ; CHECK-LABEL: @mul_splat_fold_vec(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw <3 x i14> [[X:%.*]], <i14 129, i14 129, i14 129>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[M]])
-; CHECK-NEXT:    [[T:%.*]] = and <3 x i14> [[X]], <i14 127, i14 127, i14 127>
-; CHECK-NEXT:    ret <3 x i14> [[T]]
+; CHECK-NEXT:    ret <3 x i14> [[X]]
 ;
   %m = mul nuw <3 x i14> %x, <i14 129, i14 129, i14 129>
   call void @usevec(<3 x i14> %m)

@@ -628,8 +637,6 @@ define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
   ret i32 %t
 }
 
-; Negative test (but simplifies into a different transform)
-
 define i32 @mul_splat_fold_no_nuw(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_nuw(
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 16

@@ -641,7 +648,7 @@ define i32 @mul_splat_fold_no_nuw(i32 %x) {
   ret i32 %t
 }
 
-; Negative test
+; Negative test
 
 define i32 @mul_splat_fold_no_flags(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_flags(
