Skip to content

Commit e9996d1

Browse files
committed
[InstCombine] Extend bitmask mul combine to handle independent operands
Change-Id: Ife1a010d2ae6df40549a6c73f7b893948befa3be
1 parent cc17f68 commit e9996d1

File tree

2 files changed

+123
-19
lines changed

2 files changed

+123
-19
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
36023602
APInt Mask;
36033603
bool NUW;
36043604
bool NSW;
3605+
3606+
bool isCombineableWith(DecomposedBitMaskMul Other) {
3607+
return X == Other.X && (Mask & Other.Mask).isZero() &&
3608+
Factor == Other.Factor;
3609+
}
36053610
};
36063611

36073612
static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3659,6 +3664,34 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36593664
return std::nullopt;
36603665
}
36613666

3667+
using CombinedBitmaskMul =
3668+
std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
3669+
3670+
static CombinedBitmaskMul matchCombinedBitmaskMul(Value *V) {
3671+
auto DecompBitMaskMul = matchBitmaskMul(V);
3672+
if (DecompBitMaskMul)
3673+
return {DecompBitMaskMul, nullptr};
3674+
3675+
// Otherwise, check the operands of V for bitmaskmul pattern
3676+
auto BOp = dyn_cast<BinaryOperator>(V);
3677+
if (!BOp)
3678+
return {std::nullopt, nullptr};
3679+
3680+
auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
3681+
if (!Disj || !Disj->isDisjoint())
3682+
return {std::nullopt, nullptr};
3683+
3684+
auto DecompBitMaskMul0 = matchBitmaskMul(BOp->getOperand(0));
3685+
if (DecompBitMaskMul0)
3686+
return {DecompBitMaskMul0, BOp->getOperand(1)};
3687+
3688+
auto DecompBitMaskMul1 = matchBitmaskMul(BOp->getOperand(1));
3689+
if (DecompBitMaskMul1)
3690+
return {DecompBitMaskMul1, BOp->getOperand(0)};
3691+
3692+
return {std::nullopt, nullptr};
3693+
}
3694+
36623695
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
36633696
// here. We should standardize that construct where it is needed or choose some
36643697
// other way to ensure that commutated variants of patterns are not missed.
@@ -3741,25 +3774,46 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37413774
/*NSW=*/true, /*NUW=*/true))
37423775
return R;
37433776

3744-
// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3745-
// This also accepts the equivalent select form of (A & N) * C
3746-
// expressions i.e. !(A & N) ? 0 : N * C)
3747-
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
3748-
if (Decomp1) {
3749-
auto Decomp0 = matchBitmaskMul(I.getOperand(0));
3750-
if (Decomp0 && Decomp0->X == Decomp1->X &&
3751-
(Decomp0->Mask & Decomp1->Mask).isZero() &&
3752-
Decomp0->Factor == Decomp1->Factor) {
3753-
3754-
Value *NewAnd = Builder.CreateAnd(
3755-
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
3756-
(Decomp0->Mask + Decomp1->Mask)));
3757-
3758-
auto *Combined = BinaryOperator::CreateMul(
3759-
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
3760-
3761-
Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
3762-
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
3777+
// (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
3778+
// This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
3779+
// expressions i.e. (A & N) * C
3780+
CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
3781+
auto BMDecomp1 = Decomp1.first;
3782+
3783+
if (BMDecomp1) {
3784+
CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul(I.getOperand(0));
3785+
auto BMDecomp0 = Decomp0.first;
3786+
3787+
if (BMDecomp0 && BMDecomp0->isCombineableWith(*BMDecomp1)) {
3788+
auto NewAnd = Builder.CreateAnd(
3789+
BMDecomp0->X,
3790+
ConstantInt::get(BMDecomp0->X->getType(),
3791+
(BMDecomp0->Mask + BMDecomp1->Mask)));
3792+
3793+
BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul(
3794+
NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor)));
3795+
3796+
Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
3797+
Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
3798+
3799+
// If our tree has indepdent or-disjoint operands, bring them in.
3800+
auto OtherOp0 = Decomp0.second;
3801+
auto OtherOp1 = Decomp1.second;
3802+
3803+
if (OtherOp0 || OtherOp1) {
3804+
Value *OtherOp;
3805+
if (OtherOp0 && OtherOp1) {
3806+
OtherOp = Builder.CreateOr(OtherOp0, OtherOp1);
3807+
cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint(true);
3808+
} else {
3809+
OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
3810+
}
3811+
Combined = cast<BinaryOperator>(Builder.CreateOr(Combined, OtherOp));
3812+
cast<PossiblyDisjointInst>(Combined)->setIsDisjoint(true);
3813+
}
3814+
3815+
// Caller expects detached instruction
3816+
Combined->removeFromParent();
37633817
return Combined;
37643818
}
37653819
}

llvm/test/Transforms/InstCombine/or-bitmask.ll

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,56 @@ define i32 @and_mul_non_disjoint(i32 %in) {
451451
ret i32 %out
452452
}
453453

454+
define i32 @unrelated_ops(i32 %in, i32 %in2) {
455+
; CHECK-LABEL: @unrelated_ops(
456+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
457+
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
458+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
459+
; CHECK-NEXT: ret i32 [[OUT]]
460+
;
461+
%1 = and i32 %in, 3
462+
%temp = mul nuw nsw i32 %1, 72
463+
%2 = and i32 %in, 12
464+
%temp2 = mul nuw nsw i32 %2, 72
465+
%temp3 = or disjoint i32 %in2, %temp2
466+
%out = or disjoint i32 %temp, %temp3
467+
ret i32 %out
468+
}
469+
470+
define i32 @unrelated_ops1(i32 %in, i32 %in2) {
471+
; CHECK-LABEL: @unrelated_ops1(
472+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
473+
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
474+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
475+
; CHECK-NEXT: ret i32 [[OUT]]
476+
;
477+
%1 = and i32 %in, 3
478+
%temp = mul nuw nsw i32 %1, 72
479+
%2 = and i32 %in, 12
480+
%temp2 = mul nuw nsw i32 %2, 72
481+
%temp3 = or disjoint i32 %in2, %temp
482+
%out = or disjoint i32 %temp3, %temp2
483+
ret i32 %out
484+
}
485+
486+
define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
487+
; CHECK-LABEL: @unrelated_ops2(
488+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
489+
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
490+
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
491+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[TMP3]]
492+
; CHECK-NEXT: ret i32 [[OUT]]
493+
;
494+
%1 = and i32 %in, 3
495+
%temp = mul nuw nsw i32 %1, 72
496+
%temp3 = or disjoint i32 %temp, %in3
497+
%2 = and i32 %in, 12
498+
%temp2 = mul nuw nsw i32 %2, 72
499+
%temp4 = or disjoint i32 %in2, %temp2
500+
%out = or disjoint i32 %temp3, %temp4
501+
ret i32 %out
502+
}
503+
454504
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
455505
; CONSTSPLAT: {{.*}}
456506
; CONSTVEC: {{.*}}

0 commit comments

Comments
 (0)