Skip to content

[InstCombine] Extend bitmask mul combine to handle independent operands #142503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

jrbyrnes
Copy link
Contributor

@jrbyrnes jrbyrnes commented Jun 2, 2025

This extends #136013 to capture cases where the combineable bitmask muls are nested under multiple or-disjoints.

This PR is meant for commits starting at 8c403c9

op1 = or-disjoint mul(and (X, C1), D) , reg1
op2 = or-disjoint mul(and (X, C2), D) , reg2
out = or-disjoint op1, op2

->

temp1 = or-disjoint reg1, reg2
out = or-disjoint mul(and (X, (C1 + C2)), D), temp1

Case1: https://alive2.llvm.org/ce/z/dHApyV
Case2: https://alive2.llvm.org/ce/z/Jz-Nag
Case3: https://alive2.llvm.org/ce/z/3xBnEV

@jrbyrnes jrbyrnes requested review from arsenm and dtcxzyw June 2, 2025 23:29
@jrbyrnes jrbyrnes requested a review from nikic as a code owner June 2, 2025 23:29
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms labels Jun 2, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)

Changes

This extends #136013 to capture cases where the combineable bitmask muls are nested under multiple or-disjoints.

This PR is meant for commits starting at 8c403c9

op1 = or-disjoint mul(and (X, C1), D) , reg1
op2 = or-disjoint mul(and (X, C2), D) , reg2
out = or-disjoint op1, op2

->

temp1 = or-disjoint reg1, reg2
out = or-disjoint mul(and (X, (C1 + C2)), D), temp1

Case1: https://alive2.llvm.org/ce/z/dHApyV
Case2: https://alive2.llvm.org/ce/z/Jz-Nag
Case3: https://alive2.llvm.org/ce/z/3xBnEV


Full diff: https://github.com/llvm/llvm-project/pull/142503.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+145-42)
  • (modified) llvm/test/Transforms/InstCombine/or-bitmask.ll (+136-13)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 59b46ebdb72e2..231b980506744 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3593,6 +3593,111 @@ static Value *foldOrOfInversions(BinaryOperator &I,
   return nullptr;
 }
 
+// A decomposition of ((A & N) ? 0 : N * C) . Where X = A, Factor = C, Mask = N.
+// The NUW / NSW bools
+// Note that we can decompose equivalent forms of this expression (e.g. ((A & N)
+// * C))
+struct DecomposedBitMaskMul {
+  Value *X;
+  APInt Factor;
+  APInt Mask;
+  bool NUW;
+  bool NSW;
+
+  bool isCombineableWith(DecomposedBitMaskMul Other) {
+    return X == Other.X && (Mask & Other.Mask).isZero() &&
+           Factor == Other.Factor;
+  }
+};
+
+static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
+  Instruction *Op = dyn_cast<Instruction>(V);
+  if (!Op)
+    return std::nullopt;
+
+  Value *MulOp = nullptr;
+  const APInt *MulConst = nullptr;
+
+  // Decompose (A & N) * C) into BitMaskMul
+  if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
+    Value *Original = nullptr;
+    const APInt *Mask = nullptr;
+    if (MulConst->isZero())
+      return std::nullopt;
+
+    if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
+      if (Mask->isZero())
+        return std::nullopt;
+      return std::optional<DecomposedBitMaskMul>(
+          {Original, *MulConst, *Mask,
+           cast<BinaryOperator>(Op)->hasNoUnsignedWrap(),
+           cast<BinaryOperator>(Op)->hasNoSignedWrap()});
+    }
+    return std::nullopt;
+  }
+
+  Value *Cond = nullptr;
+  const APInt *EqZero = nullptr, *NeZero = nullptr;
+
+  // Decompose ((A & N) ? 0 : N * C) into BitMaskMul
+  if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) {
+    auto ICmpDecompose =
+        decomposeBitTest(Cond, /*LookThruTrunc=*/true,
+                         /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
+    if (!ICmpDecompose.has_value())
+      return std::nullopt;
+
+    if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
+      std::swap(EqZero, NeZero);
+
+    if (!EqZero->isZero() || NeZero->isZero())
+      return std::nullopt;
+
+    if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
+        !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
+        ICmpDecompose->Mask.isZero() ||
+        NeZero->getBitWidth() != ICmpDecompose->Mask.getBitWidth())
+      return std::nullopt;
+
+    if (!NeZero->urem(ICmpDecompose->Mask).isZero())
+      return std::nullopt;
+
+    return std::optional<DecomposedBitMaskMul>(
+        {ICmpDecompose->X, NeZero->udiv(ICmpDecompose->Mask),
+         ICmpDecompose->Mask, /*NUW=*/false, /*NSW=*/false});
+  }
+
+  return std::nullopt;
+}
+
+using CombinedBitmaskMul =
+    std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
+
+static CombinedBitmaskMul matchCombinedBitmaskMul(Value *V) {
+  auto DecompBitMaskMul = matchBitmaskMul(V);
+  if (DecompBitMaskMul)
+    return {DecompBitMaskMul, nullptr};
+
+  // Otherwise, check the operands of V for bitmaskmul pattern
+  auto BOp = dyn_cast<BinaryOperator>(V);
+  if (!BOp)
+    return {std::nullopt, nullptr};
+
+  auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
+  if (!Disj || !Disj->isDisjoint())
+    return {std::nullopt, nullptr};
+
+  auto DecompBitMaskMul0 = matchBitmaskMul(BOp->getOperand(0));
+  if (DecompBitMaskMul0)
+    return {DecompBitMaskMul0, BOp->getOperand(1)};
+
+  auto DecompBitMaskMul1 = matchBitmaskMul(BOp->getOperand(1));
+  if (DecompBitMaskMul1)
+    return {DecompBitMaskMul1, BOp->getOperand(0)};
+
+  return {std::nullopt, nullptr};
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3675,49 +3780,47 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    Value *Cond0 = nullptr, *Cond1 = nullptr;
-    const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
-    const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
-
-    //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
-    if (match(I.getOperand(0),
-              m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
-        match(I.getOperand(1),
-              m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
-
-      auto LHSDecompose =
-          decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
-                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-      auto RHSDecompose =
-          decomposeBitTest(Cond1, /*LookThruTrunc=*/true,
-                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-
-      if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X &&
-          RHSDecompose->Mask.isPowerOf2() && LHSDecompose->Mask.isPowerOf2() &&
-          LHSDecompose->Mask != RHSDecompose->Mask &&
-          LHSDecompose->Mask.getBitWidth() == Op0Ne->getBitWidth() &&
-          RHSDecompose->Mask.getBitWidth() == Op1Ne->getBitWidth()) {
-        assert(Op0Ne->getBitWidth() == Op1Ne->getBitWidth());
-        assert(ICmpInst::isEquality(LHSDecompose->Pred));
-        if (LHSDecompose->Pred == ICmpInst::ICMP_NE)
-          std::swap(Op0Eq, Op0Ne);
-        if (RHSDecompose->Pred == ICmpInst::ICMP_NE)
-          std::swap(Op1Eq, Op1Ne);
-
-        if (!Op0Ne->isZero() && !Op1Ne->isZero() && Op0Eq->isZero() &&
-            Op1Eq->isZero() && Op0Ne->urem(LHSDecompose->Mask).isZero() &&
-            Op1Ne->urem(RHSDecompose->Mask).isZero() &&
-            Op0Ne->udiv(LHSDecompose->Mask) ==
-                Op1Ne->udiv(RHSDecompose->Mask)) {
-          auto NewAnd = Builder.CreateAnd(
-              LHSDecompose->X,
-              ConstantInt::get(LHSDecompose->X->getType(),
-                               (LHSDecompose->Mask + RHSDecompose->Mask)));
-
-          return BinaryOperator::CreateMul(
-              NewAnd, ConstantInt::get(NewAnd->getType(),
-                                       Op0Ne->udiv(LHSDecompose->Mask)));
+    // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+    // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
+    // expressions i.e. (A & N) * C
+    CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
+    auto BMDecomp1 = Decomp1.first;
+
+    if (BMDecomp1) {
+      CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul(I.getOperand(0));
+      auto BMDecomp0 = Decomp0.first;
+
+      if (BMDecomp0 && BMDecomp0->isCombineableWith(*BMDecomp1)) {
+        auto NewAnd = Builder.CreateAnd(
+            BMDecomp0->X,
+            ConstantInt::get(BMDecomp0->X->getType(),
+                             (BMDecomp0->Mask + BMDecomp1->Mask)));
+
+        BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul(
+            NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor)));
+
+        Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
+        Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
+
+        // If our tree has indepdent or-disjoint operands, bring them in.
+        auto OtherOp0 = Decomp0.second;
+        auto OtherOp1 = Decomp1.second;
+
+        if (OtherOp0 || OtherOp1) {
+          Value *OtherOp;
+          if (OtherOp0 && OtherOp1) {
+            OtherOp = Builder.CreateOr(OtherOp0, OtherOp1);
+            cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint(true);
+          } else {
+            OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
+          }
+          Combined = cast<BinaryOperator>(Builder.CreateOr(Combined, OtherOp));
+          cast<PossiblyDisjointInst>(Combined)->setIsDisjoint(true);
         }
+
+        // Caller expects detached instruction
+        Combined->removeFromParent();
+        return Combined;
       }
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 3b482dc1794db..8a2e328fa95cb 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) {
 
 define i32 @add_select_cmp_and3(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and3(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
-; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    [[BITOP2:%.*]] = and i32 [[IN]], 4
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
-; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
-; CHECK-NEXT:    ret i32 [[OUT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
+; CHECK-NEXT:    [[TEMP1:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[TEMP1]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) {
 
 define i32 @add_select_cmp_and4(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and4(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
-; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN]], 12
-; CHECK-NEXT:    [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
-; CHECK-NEXT:    [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]]
-; CHECK-NEXT:    ret i32 [[OUT1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT:    ret i32 [[TEMP2]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -361,6 +354,136 @@ define i64 @mask_select_types_1(i64 %in) {
   ret i64 %out
 }
 
+define i32 @add_select_cmp_mixed1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask = and i32 %in, 1
+  %sel0 = mul i32 %mask, 72
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %mask = and i32 %in, 2
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = mul i32 %mask, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask0 = and i32 %in, 1
+  %sel0 = mul i32 %mask0, 72
+  %mask1 = and i32 %in, 2
+  %sel1 = mul i32 %mask1, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2_mismatch(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73
+; CHECK-NEXT:    [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %mask = and i32 %in, 2
+  %sel0 = select i1 %cmp0, i32 0, i32 73
+  %sel1 = mul i32 %mask, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul_mismatch(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0
+; CHECK-NEXT:    [[MASK1:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask0 = and i32 %in, 1
+  %sel0 = mul i32 %mask0, 73
+  %mask1 = and i32 %in, 2
+  %sel1 = mul i32 %mask1, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @unrelated_ops(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp3 = or disjoint i32 %in2, %temp2
+  %out = or disjoint i32 %temp, %temp3
+  ret i32 %out
+}
+
+define i32 @unrelated_ops1(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp3 = or disjoint i32 %in2, %temp
+  %out = or disjoint i32 %temp3, %temp2
+  ret i32 %out
+}
+
+define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
+; CHECK-LABEL: @unrelated_ops2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %temp3 = or disjoint i32 %temp, %in3
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp4 = or disjoint i32 %in2, %temp2
+  %out = or disjoint i32 %temp3, %temp4
+  ret i32 %out
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; CONSTSPLAT: {{.*}}
 ; CONSTVEC: {{.*}}

Change-Id: Ife1a010d2ae6df40549a6c73f7b893948befa3be
@jrbyrnes jrbyrnes force-pushed the ICAndSelExtendCase1 branch from 8c403c9 to e9996d1 Compare June 12, 2025 15:47
@jrbyrnes
Copy link
Contributor Author

Rebase for #135274

Change-Id: I879acdf0b17a7110286c6c375410300611c468eb
@jrbyrnes
Copy link
Contributor Author

Ping -- any concerns?

@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 19, 2025

op1 = or-disjoint mul(and (X, C1), D) , reg1
op2 = or-disjoint mul(and (X, C2), D) , reg2
out = or-disjoint op1, op2
->
temp1 = or-disjoint reg1, reg2
out = or-disjoint mul(and (X, (C1 + C2)), D), temp1

Can we perform the reassociation first in ReassociatePass or InstCombinePass to get

or-disjoint (mul(and (X, C1), D), mul(and (X, C2), D))

Then this pattern should be handled by #136013.

Change-Id: Ib86e8ed347ef60948c3e4cb44c5fab1c3667afc6
@jrbyrnes
Copy link
Contributor Author

jrbyrnes commented Jun 23, 2025

Okay -- the latest version reassociates chains with independent operands into a more combineable pattern (rather than doing the combine itself).

We should only be doing this reassociation if we can find a bitmask mul operand on both sides, otherwise we may end up infinitely reassociating. Since we require the DecomposedBitMaskMul analysis, I've decided to do the reassocation in InstCombineAndOrXor where the analysis already exists.

The implementation does seem a little simpler, but we now rely on future iterations of InstCombine to perform the optimization. Moreover, we still need to do the analysis to check for the conditions of reassociation. Thus, it is somewhat unclear if this change is a net positive.

@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 24, 2025

Moreover, we still need to do the analysis to check for the conditions of reassociation. Thus, it is somewhat unclear if this change is a net positive.

I don't think it is necessary to check the specific pattern. As these two muls share the same constant factor, it is always profitable to reassociate the tree. ReassociatePass performs the reassociate for add, but not for or disjoint: https://godbolt.org/z/TMcYTdd5T It may be better to handle this pattern in ReassociatePass.

Change-Id: I950ee32ec053430fd51c7fd52645fe52e9e6ecff
@jrbyrnes
Copy link
Contributor Author

jrbyrnes commented Jun 25, 2025

Moreover, we still need to do the analysis to check for the conditions of reassociation. Thus, it is somewhat unclear if this change is a net positive.

I don't think it is necessary to check the specific pattern. As these two muls share the same constant factor, it is always profitable to reassociate the tree. ReassociatePass performs the reassociate for add, but not for or disjoint: https://godbolt.org/z/TMcYTdd5T It may be better to handle this pattern in ReassociatePass.

Sorry, did not see your comment originally.

I think we can extend the reassociation pass to handle cases like https://godbolt.org/z/7vWePe1xK , but I don't think this will be sufficient for my usecase.

The problem is #133139 (comment)

In other words:

%masked = and i32 %val, CONSTANT
%cond = icmp eq i32 %masked, 0
%out = select i1 %cond, i32 0, i32 CONSTANT * FACTOR

is the canonical form of

%masked = and %val, CONSTANT
%out = mul i32 %masked, FACTOR

The premise of the work starting at #136013 is that if we are combining multiple of these forms under an or-disjoint, it is preferable to have a single and->mul. This PR and subsequent PRs (including the PR I'm currently commenting on) have transformed the select form into and->mul given conditions which make the transform unarguably beneficial, however, it is still very much possible that we have the select form.

Thus, to be sufficient for my usecase, it seems that reassociation must be able to handle the case where one or more of the leaf forms are and->icmp->select (instead of and->mul). Presently, this is not done even for the add case https://godbolt.org/z/W1YxhrWa9 .

Potentially, we could make this extension to the add case, then extend the or-disjoint to support all forms. However, if possible I would prefer to bring in something along the lines of what this PR is currently doing, potentially to be deleted after we improve reassociation -- this work has already been stalled for some time, and going the reassociation route seems like it would open a whole new body of work.

Change-Id: I875b9fac4749b3f391efce47f8d3b9e2004de8c2
@jrbyrnes
Copy link
Contributor Author

Thus, to be sufficient for my usecase, it seems that reassociation must be able to handle the case where one or more of the leaf forms are and->icmp->select (instead of and->mul)

I see that I did not make this explicitly clear in the description, nor did I cover it in any tests. Sorry about that.

@nikic
Copy link
Contributor

nikic commented Jun 25, 2025

The problem is #133139 (comment)

In other words:

%masked = and i32 %val, CONSTANT
%cond = icmp eq i32 %masked, 0
%out = select i1 %cond, i32 0, i32 CONSTANT * FACTOR

is the canonical form of

%masked = and %val, CONSTANT
%out = mul i32 %masked, FACTOR

The premise of the work starting at #136013 is that if we are combining multiple of these forms under an or-disjoint, it is preferable to have a single and->mul. This PR and subsequent PRs (including the PR I'm currently commenting on) have transformed the select form into and->mul given conditions which make the transform unarguably beneficial, however, it is still very much possible that we have the select form.

I wonder whether we should reconsider that decision. The transform is reversible, so if necessary we can have an undo transform from and+mul to and+icmp+select in the backend.


On this PR, I have two notes:

  • This is something I missed in the previous iterations, but it's even more important now: We need one-use checks to ensure we're not emitting more instructions if some of the old ones have extra uses.
  • I don't really like the implementation approach here. We do something similar already in reassociateBooleanAndOr and I think should follow that general pattern. Which would basically be something like this:
if (Value *Res = tryFold(Op0, Op1))
  return Res;
if (match(Op0, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y)))) {
  if (Value *Res = tryFold(X, Op1))
    return Builder.CreateDisjointOr(Res, Y);
  if (Value *Res = tryFold(Y, Op1);
    return Builder.CreateDisjointOr(Res, X);
}
if (match(Op1, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y)))) {
  if (Value *Res = tryFold(X, Op0))
    return Builder.CreateDisjointOr(Res, Y);
  if (Value *Res = tryFold(Y, Op0);
    return Builder.CreateDisjointOr(Res, X);
}

The nice thing about this is that if we have more folds where we think it is worthwhile to perform them via reassociation, it's just a matter of adding them to tryFold(), we don't need to invent special handling for each one of them.

@jrbyrnes
Copy link
Contributor Author

jrbyrnes commented Jun 26, 2025

  • I don't really like the implementation approach here. We do something similar already in reassociateBooleanAndOr and I think should follow that general pattern.

Looking into the details -- turning out to be a larger refactor than I had thought. How do you feel about providing the functionality using the current implementation, then refactoring into reassociateBooleanAndOr as an extension?

EDIT: posted what I have -- can revert / split-up as needed.

Change-Id: Ie86d0e58f7fdb2c0489d3dee3a41ef3911f9477b
Copy link

github-actions bot commented Jun 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Change-Id: I2c418b8e5bf7fed050ee77515a73fa4368a1ea7d
@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 27, 2025

Looking into the details -- turning out to be a larger refactor than I had thought. How do you feel about providing the functionality using the current implementation, then refactoring into reassociateBooleanAndOr as an extension?

You don't need to refactor reassociateBooleanAndOr. It also handles boolean logical and/or so it is not suitable for non-boolean or instructions. I guess nikic meant to add a new helper like reassociateDisjointOr.

@nikic
Copy link
Contributor

nikic commented Jun 27, 2025

Yes, I just meant the general code pattern, not to reuse that specific helper.

Change-Id: I172be21fe78361f4520a893d6c97c422accbf13f
@jrbyrnes
Copy link
Contributor Author

Thanks for the clarification.

The latest introduces reassociateDisjointOr . This method takes the operands of an or-disjoint and checks all permutations of child trees when one or both of the operands are themselves or-disjoints.

These permutations are checked for foldDisjointOr patterns -- this currently just checks foldBitmaskMul but can easily be extended.

I've tried to follow the implementation and caller pattern for reassociateBooleanAndOr but I found the structure to be quite confusing. In particular, I found it strange that the caller was responsible for deconstructing a tree, and the method was responsible for permuting it -- this created some strange contracts between caller/method, especially in the case where both operands were or-disjoints.

Thus, reassociateDisjointOr is responsible for both deconstructing the tree and permuting it, and the caller is just responsible for providing the operands of or-disjoint.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants