Skip to content

Commit 8acd710

Browse files
jrbyrnes authored and tomtor committed
[InstCombine] Combine or-disjoint (and->mul), (and->mul) to and->mul (llvm#136013)
The canonical pattern for a bitmasked mul is currently:

```
%val = and %x, %bitMask // where %bitMask is some constant
%cmp = icmp eq %val, 0
%sel = select %cmp, 0, %C // where %C is some constant = C' * %bitMask
```

In certain cases, where we are combining multiple of these bitmasked muls with common factors, we are able to optimize into and->mul (see llvm#135274).

This optimization lends itself to further optimizations. This PR addresses one such optimization.

In cases where we have `or-disjoint (mul (and (X, C1), D), mul (and (X, C2), D))` we can combine into `mul (and (X, (C1 + C2)), D)`, provided C1 and C2 are disjoint.

Generalized proof: https://alive2.llvm.org/ce/z/MQYMui
1 parent 8a0a86d commit 8acd710

File tree

2 files changed

+190
-56
lines changed

2 files changed

+190
-56
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 87 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3592,6 +3592,73 @@ static Value *foldOrOfInversions(BinaryOperator &I,
35923592
return nullptr;
35933593
}
35943594

3595+
// A decomposition of ((X & Mask) * Factor). The NUW / NSW bools
3596+
// track these properities for preservation. Note that we can decompose
3597+
// equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask *
3598+
// Factor))
3599+
struct DecomposedBitMaskMul {
3600+
Value *X;
3601+
APInt Factor;
3602+
APInt Mask;
3603+
bool NUW;
3604+
bool NSW;
3605+
};
3606+
3607+
// Try to interpret V as a bit-masked multiply, ((X & Mask) * Factor).
//
// Two shapes are recognized:
//   1. An explicit multiply, (A & N) * C, where N and C are non-zero
//      constants. The nuw / nsw flags of the mul are carried into the result.
//   2. The equivalent select form, (!(A & N) ? 0 : N * C), where the condition
//      decomposes (via decomposeBitTest) into an equality test of a
//      power-of-two mask against zero. The result reports both no-wrap flags
//      as false for this shape.
//
// Returns std::nullopt if V is not an instruction or matches neither shape.
static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
  auto *Inst = dyn_cast<Instruction>(V);
  if (!Inst)
    return std::nullopt;

  // Shape 1: decompose ((A & N) * C) into a DecomposedBitMaskMul.
  Value *Src = nullptr;
  const APInt *AndMask = nullptr;
  const APInt *Multiplier = nullptr;
  if (match(Inst, m_Mul(m_And(m_Value(Src), m_APInt(AndMask)),
                        m_APInt(Multiplier)))) {
    // A zero factor or a zero mask is rejected rather than decomposed.
    if (Multiplier->isZero() || AndMask->isZero())
      return std::nullopt;

    auto *Mul = cast<BinaryOperator>(Inst);
    return DecomposedBitMaskMul{Src, *Multiplier, *AndMask,
                                Mul->hasNoUnsignedWrap(),
                                Mul->hasNoSignedWrap()};
  }

  // Shape 2: decompose (!(A & N) ? 0 : N * C) into a DecomposedBitMaskMul.
  Value *Cond = nullptr;
  const APInt *TrueArm = nullptr, *FalseArm = nullptr;
  if (!match(Inst,
             m_Select(m_Value(Cond), m_APInt(TrueArm), m_APInt(FalseArm))))
    return std::nullopt;

  auto BitTest =
      decomposeBitTest(Cond, /*LookThruTrunc=*/true,
                       /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
  if (!BitTest.has_value())
    return std::nullopt;

  assert(ICmpInst::isEquality(BitTest->Pred) && BitTest->C.isZero());

  // With an `eq` predicate the true arm is the value selected when the masked
  // bits are clear; with `ne` the roles of the two arms are exchanged.
  const bool IsNe = BitTest->Pred == ICmpInst::ICMP_NE;
  const APInt *WhenClear = IsNe ? FalseArm : TrueArm;
  const APInt *WhenSet = IsNe ? TrueArm : FalseArm;

  // The "bits clear" arm must be 0 and the "bits set" arm must be non-zero.
  if (!WhenClear->isZero() || WhenSet->isZero())
    return std::nullopt;

  // Only a single-bit mask of the same width as the select constants is
  // accepted; the width check must precede the urem / udiv below.
  const APInt &BitMask = BitTest->Mask;
  if (!BitMask.isPowerOf2() || BitMask.isZero() ||
      WhenSet->getBitWidth() != BitMask.getBitWidth())
    return std::nullopt;

  // The selected constant has to be an exact multiple of the mask so that it
  // can be written as Mask * Factor.
  if (!WhenSet->urem(BitMask).isZero())
    return std::nullopt;

  return DecomposedBitMaskMul{BitTest->X, WhenSet->udiv(BitMask), BitMask,
                              /*NUW=*/false, /*NSW=*/false};
}
3661+
35953662
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
35963663
// here. We should standardize that construct where it is needed or choose some
35973664
// other way to ensure that commutated variants of patterns are not missed.
@@ -3674,49 +3741,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
36743741
/*NSW=*/true, /*NUW=*/true))
36753742
return R;
36763743

3677-
Value *Cond0 = nullptr, *Cond1 = nullptr;
3678-
const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
3679-
const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
3680-
3681-
// (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
3682-
if (match(I.getOperand(0),
3683-
m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
3684-
match(I.getOperand(1),
3685-
m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
3686-
3687-
auto LHSDecompose =
3688-
decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
3689-
/*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
3690-
auto RHSDecompose =
3691-
decomposeBitTest(Cond1, /*LookThruTrunc=*/true,
3692-
/*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
3693-
3694-
if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X &&
3695-
RHSDecompose->Mask.isPowerOf2() && LHSDecompose->Mask.isPowerOf2() &&
3696-
LHSDecompose->Mask != RHSDecompose->Mask &&
3697-
LHSDecompose->Mask.getBitWidth() == Op0Ne->getBitWidth() &&
3698-
RHSDecompose->Mask.getBitWidth() == Op1Ne->getBitWidth()) {
3699-
assert(Op0Ne->getBitWidth() == Op1Ne->getBitWidth());
3700-
assert(ICmpInst::isEquality(LHSDecompose->Pred));
3701-
if (LHSDecompose->Pred == ICmpInst::ICMP_NE)
3702-
std::swap(Op0Eq, Op0Ne);
3703-
if (RHSDecompose->Pred == ICmpInst::ICMP_NE)
3704-
std::swap(Op1Eq, Op1Ne);
3705-
3706-
if (!Op0Ne->isZero() && !Op1Ne->isZero() && Op0Eq->isZero() &&
3707-
Op1Eq->isZero() && Op0Ne->urem(LHSDecompose->Mask).isZero() &&
3708-
Op1Ne->urem(RHSDecompose->Mask).isZero() &&
3709-
Op0Ne->udiv(LHSDecompose->Mask) ==
3710-
Op1Ne->udiv(RHSDecompose->Mask)) {
3711-
auto NewAnd = Builder.CreateAnd(
3712-
LHSDecompose->X,
3713-
ConstantInt::get(LHSDecompose->X->getType(),
3714-
(LHSDecompose->Mask + RHSDecompose->Mask)));
3715-
3716-
return BinaryOperator::CreateMul(
3717-
NewAnd, ConstantInt::get(NewAnd->getType(),
3718-
Op0Ne->udiv(LHSDecompose->Mask)));
3719-
}
3744+
// (A & N) * C + (A & M) * C -> (A & (N + M)) * C
3745+
// This also accepts the equivalent select form of (A & N) * C
3746+
// expressions, i.e. (!(A & N) ? 0 : N * C).
3747+
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
3748+
if (Decomp1) {
3749+
auto Decomp0 = matchBitmaskMul(I.getOperand(0));
3750+
if (Decomp0 && Decomp0->X == Decomp1->X &&
3751+
(Decomp0->Mask & Decomp1->Mask).isZero() &&
3752+
Decomp0->Factor == Decomp1->Factor) {
3753+
3754+
Value *NewAnd = Builder.CreateAnd(
3755+
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
3756+
(Decomp0->Mask + Decomp1->Mask)));
3757+
3758+
auto *Combined = BinaryOperator::CreateMul(
3759+
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
3760+
3761+
Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
3762+
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
3763+
return Combined;
37203764
}
37213765
}
37223766
}

llvm/test/Transforms/InstCombine/or-bitmask.ll

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) {
3636

3737
define i32 @add_select_cmp_and3(i32 %in) {
3838
; CHECK-LABEL: @add_select_cmp_and3(
39-
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
40-
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
41-
; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4
42-
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
43-
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
44-
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
45-
; CHECK-NEXT: ret i32 [[OUT]]
39+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
40+
; CHECK-NEXT: [[TEMP1:%.*]] = mul nuw nsw i32 [[TMP1]], 72
41+
; CHECK-NEXT: ret i32 [[TEMP1]]
4642
;
4743
%bitop0 = and i32 %in, 1
4844
%cmp0 = icmp eq i32 %bitop0, 0
@@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) {
6056

6157
define i32 @add_select_cmp_and4(i32 %in) {
6258
; CHECK-LABEL: @add_select_cmp_and4(
63-
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
64-
; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
65-
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
66-
; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
67-
; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]]
68-
; CHECK-NEXT: ret i32 [[OUT1]]
59+
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
60+
; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
61+
; CHECK-NEXT: ret i32 [[TEMP2]]
6962
;
7063
%bitop0 = and i32 %in, 1
7164
%cmp0 = icmp eq i32 %bitop0, 0
@@ -361,6 +354,103 @@ define i64 @mask_select_types_1(i64 %in) {
361354
ret i64 %out
362355
}
363356

357+
; Mixed operand forms: operand 0 is the mul form ((in & 1) * 72) and operand 1
; is the select form ((in & 2) == 0 ? 0 : 144); the CHECK lines expect the
; `or disjoint` to fold to (in & 3) * 72.
define i32 @add_select_cmp_mixed1(i32 %in) {
358+
; CHECK-LABEL: @add_select_cmp_mixed1(
359+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
360+
; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
361+
; CHECK-NEXT: ret i32 [[OUT]]
362+
;
363+
%mask = and i32 %in, 1
364+
%sel0 = mul i32 %mask, 72
365+
%bitop1 = and i32 %in, 2
366+
%cmp1 = icmp eq i32 %bitop1, 0
367+
%sel1 = select i1 %cmp1, i32 0, i32 144
368+
%out = or disjoint i32 %sel0, %sel1
369+
ret i32 %out
370+
}
371+
372+
; Mixed operand forms, reversed: operand 0 is the select form
; ((in & 1) == 0 ? 0 : 72) and operand 1 is the mul form ((in & 2) * 72); the
; CHECK lines expect the `or disjoint` to fold to (in & 3) * 72.
define i32 @add_select_cmp_mixed2(i32 %in) {
373+
; CHECK-LABEL: @add_select_cmp_mixed2(
374+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
375+
; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
376+
; CHECK-NEXT: ret i32 [[OUT]]
377+
;
378+
%bitop0 = and i32 %in, 1
379+
%cmp0 = icmp eq i32 %bitop0, 0
380+
%mask = and i32 %in, 2
381+
%sel0 = select i1 %cmp0, i32 0, i32 72
382+
%sel1 = mul i32 %mask, 72
383+
%out = or disjoint i32 %sel0, %sel1
384+
ret i32 %out
385+
}
386+
387+
; Both operands are in mul form with masks 1 and 2 and the same factor 72; the
; CHECK lines expect the `or disjoint` to fold to (in & 3) * 72.
define i32 @add_select_cmp_and_mul(i32 %in) {
388+
; CHECK-LABEL: @add_select_cmp_and_mul(
389+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
390+
; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
391+
; CHECK-NEXT: ret i32 [[OUT]]
392+
;
393+
%mask0 = and i32 %in, 1
394+
%sel0 = mul i32 %mask0, 72
395+
%mask1 = and i32 %in, 2
396+
%sel1 = mul i32 %mask1, 72
397+
%out = or disjoint i32 %sel0, %sel1
398+
ret i32 %out
399+
}
400+
401+
; Negative test: the select form yields 73 for mask 1 while the mul form uses
; factor 72, so the factors differ; the CHECK lines expect the `or disjoint`
; of the two operands to remain.
define i32 @add_select_cmp_mixed2_mismatch(i32 %in) {
402+
; CHECK-LABEL: @add_select_cmp_mixed2_mismatch(
403+
; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
404+
; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
405+
; CHECK-NEXT: [[MASK:%.*]] = and i32 [[IN]], 2
406+
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73
407+
; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72
408+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
409+
; CHECK-NEXT: ret i32 [[OUT]]
410+
;
411+
%bitop0 = and i32 %in, 1
412+
%cmp0 = icmp eq i32 %bitop0, 0
413+
%mask = and i32 %in, 2
414+
%sel0 = select i1 %cmp0, i32 0, i32 73
415+
%sel1 = mul i32 %mask, 72
416+
%out = or disjoint i32 %sel0, %sel1
417+
ret i32 %out
418+
}
419+
420+
; Negative test: both operands are in mul form but with different factors
; (73 vs 72); the CHECK lines expect the `or disjoint` to remain.
define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
421+
; CHECK-LABEL: @add_select_cmp_and_mul_mismatch(
422+
; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1
423+
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0
424+
; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 2
425+
; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
426+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
427+
; CHECK-NEXT: ret i32 [[OUT]]
428+
;
429+
%mask0 = and i32 %in, 1
430+
%sel0 = mul i32 %mask0, 73
431+
%mask1 = and i32 %in, 2
432+
%sel1 = mul i32 %mask1, 72
433+
%out = or disjoint i32 %sel0, %sel1
434+
ret i32 %out
435+
}
436+
437+
; Negative test: both operands are in mul form with masks 2 and 4 and the same
; factor 72, but the `or` does not carry the `disjoint` flag; the CHECK lines
; expect the `or` to remain.
define i32 @and_mul_non_disjoint(i32 %in) {
438+
; CHECK-LABEL: @and_mul_non_disjoint(
439+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 2
440+
; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
441+
; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 4
442+
; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
443+
; CHECK-NEXT: [[OUT1:%.*]] = or i32 [[OUT]], [[SEL1]]
444+
; CHECK-NEXT: ret i32 [[OUT1]]
445+
;
446+
%mask0 = and i32 %in, 2
447+
%sel0 = mul i32 %mask0, 72
448+
%mask1 = and i32 %in, 4
449+
%sel1 = mul i32 %mask1, 72
450+
%out = or i32 %sel0, %sel1
451+
ret i32 %out
452+
}
453+
364454
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
365455
; CONSTSPLAT: {{.*}}
366456
; CONSTVEC: {{.*}}

0 commit comments

Comments
 (0)