Skip to content

[CmpInstAnalysis] Decompose icmp eq (and x, C) C2 #136367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 24, 2025
Merged

Conversation

jrbyrnes
Copy link
Contributor

This type of decomposition is used in multiple places already. Adding it to CmpInstAnalysis reduces code duplication.

Change-Id: I1dd786a4652ccd2e486db6903e16e58ffa1a7959
@jrbyrnes jrbyrnes requested review from arsenm, andjo403 and dtcxzyw April 18, 2025 20:35
@jrbyrnes jrbyrnes requested a review from nikic as a code owner April 18, 2025 20:36
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Apr 18, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)

Changes

This type of decomposition is used in multiple places already. Adding it to CmpInstAnalysis reduces code duplication.


Full diff: https://github.com/llvm/llvm-project/pull/136367.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/CmpInstAnalysis.h (+9-5)
  • (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+26-8)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+4-10)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+15-17)
  • (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+5-9)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X;
+    Value *X = nullptr;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
   };
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
-  /// Unless \p AllowNonZeroC is true, C will always be 0.
+  /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+  /// DecomposeBitMask is specified, then, for equality predicates, this will
+  /// decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true,
-                       bool AllowNonZeroC = false);
+                       bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+                       bool DecomposeBitMask = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  /// If \p DecomposeBitMask is specified, then, for equality predicates, this
+  /// will decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false);
+                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc, bool AllowNonZeroC) {
+                           bool LookThruTrunc, bool AllowNonZeroC,
+                           bool DecomposeBitMask) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+      !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
   bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
+
   switch (Pred) {
   default:
-    llvm_unreachable("Unexpected predicate");
+    return std::nullopt;
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (C.isPowerOf2()) {
       Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE: {
+    assert(DecomposeBitMask);
+    const APInt *AndC;
+    Value *AndVal;
+    if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      break;
+    }
+
+    return std::nullopt;
+  }
+  }
 
   if (!AllowNonZeroC && !Result.C.isZero())
     return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else {
+  } else if (!Result.X) {
     Result.X = LHS;
   }
 
   return Result;
 }
 
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+                                                        bool LookThruTrunc,
+                                                        bool AllowNonZeroC,
+                                                        bool DecomposeBitMask) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC);
+                                AllowNonZeroC, DecomposeBitMask);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    auto Res =
-        llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, /*LookThroughTrunc=*/false);
+    auto Res = llvm::decomposeBitTestICmp(
+        ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*DecomposeBitMask=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
       return true;
     }
 
-    // Is it  icmp eq (X & Mask), 0  already?
-    const APInt *Mask;
-    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
-        Pred == ICmpInst::ICMP_EQ) {
-      UnsetBitsMask = *Mask;
-      return true;
-    }
     return false;
   };
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
+    auto Res = decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
 
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
+    if (!Res)
+      return nullptr;
 
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
+    V = CmpLHS;
+    AndMask = Res->Mask;
+
+    if (!ICmpInst::isEquality(Pred)) {
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
       CreateAnd = true;
-    } else {
-      return nullptr;
     }
+
+    Pred = Res->Pred;
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
+
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
                              m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
                                              CurLoop))));
   };
-  auto MatchConstantBitMask = [&]() {
-    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
-           match(CmpLHS, m_And(m_Value(CurrX),
-                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
-           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
-  };
+
   auto MatchDecomposableConstantBitMask = [&]() {
-    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    auto Res = llvm::decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
     return false;
   };
 
-  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
-      !MatchDecomposableConstantBitMask()) {
+  if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
     LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
     return false;
   }

@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Jeffrey Byrnes (jrbyrnes)

Changes

This type of decomposition is used in multiple places already. Adding it to CmpInstAnalysis reduces code duplication.


Full diff: https://github.com/llvm/llvm-project/pull/136367.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/CmpInstAnalysis.h (+9-5)
  • (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+26-8)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+4-10)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+15-17)
  • (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+5-9)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X;
+    Value *X = nullptr;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
   };
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
-  /// Unless \p AllowNonZeroC is true, C will always be 0.
+  /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+  /// DecomposeBitMask is specified, then, for equality predicates, this will
+  /// decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true,
-                       bool AllowNonZeroC = false);
+                       bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+                       bool DecomposeBitMask = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  /// If \p DecomposeBitMask is specified, then, for equality predicates, this
+  /// will decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false);
+                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc, bool AllowNonZeroC) {
+                           bool LookThruTrunc, bool AllowNonZeroC,
+                           bool DecomposeBitMask) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+      !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
   bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
+
   switch (Pred) {
   default:
-    llvm_unreachable("Unexpected predicate");
+    return std::nullopt;
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (C.isPowerOf2()) {
       Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE: {
+    assert(DecomposeBitMask);
+    const APInt *AndC;
+    Value *AndVal;
+    if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      break;
+    }
+
+    return std::nullopt;
+  }
+  }
 
   if (!AllowNonZeroC && !Result.C.isZero())
     return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else {
+  } else if (!Result.X) {
     Result.X = LHS;
   }
 
   return Result;
 }
 
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+                                                        bool LookThruTrunc,
+                                                        bool AllowNonZeroC,
+                                                        bool DecomposeBitMask) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC);
+                                AllowNonZeroC, DecomposeBitMask);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    auto Res =
-        llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, /*LookThroughTrunc=*/false);
+    auto Res = llvm::decomposeBitTestICmp(
+        ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*DecomposeBitMask=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
       return true;
     }
 
-    // Is it  icmp eq (X & Mask), 0  already?
-    const APInt *Mask;
-    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
-        Pred == ICmpInst::ICMP_EQ) {
-      UnsetBitsMask = *Mask;
-      return true;
-    }
     return false;
   };
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
+    auto Res = decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
 
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
+    if (!Res)
+      return nullptr;
 
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
+    V = CmpLHS;
+    AndMask = Res->Mask;
+
+    if (!ICmpInst::isEquality(Pred)) {
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
       CreateAnd = true;
-    } else {
-      return nullptr;
     }
+
+    Pred = Res->Pred;
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
+
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
                              m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
                                              CurLoop))));
   };
-  auto MatchConstantBitMask = [&]() {
-    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
-           match(CmpLHS, m_And(m_Value(CurrX),
-                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
-           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
-  };
+
   auto MatchDecomposableConstantBitMask = [&]() {
-    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    auto Res = llvm::decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
     return false;
   };
 
-  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
-      !MatchDecomposableConstantBitMask()) {
+  if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
     LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
     return false;
   }

@jrbyrnes
Copy link
Contributor Author

It's NFC for the current users.

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Please wait for additional approval from other reviewers.
BTW this patch should not block #135274. See my previous comment #135274 (comment).

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Historically the purpose of decomposeBitTestICmp() was to decompose plain icmps into and+icmp form, but I don't see a reason not to support the and+icmp form directly for cases where the distinction does not matter.

Change-Id: I5355b103c34cb5d49868f44e3f8c10b95764a531
@jrbyrnes
Copy link
Contributor Author

BTW this patch should not block #135274. See my previous comment #135274 (comment).

Thanks -- it seems this patch is close to landing, so I'll keep #135274 as it is for now . But if this hasn't landed in the next day or so I will decouple them.

@jrbyrnes
Copy link
Contributor Author

Needs additional approval from @nikic

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Pred, /*LookThroughTrunc=*/false);
auto Res = llvm::decomposeBitTestICmp(
ICmp->getOperand(0), ICmp->getOperand(1), Pred,
/*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
Copy link
Contributor

@andjo403 andjo403 Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
/*LookThroughTrunc=*/false, /*AllowNonZeroC=*/false,

do not think that this shall start to look through trunc as that case have special handling below

Change-Id: Id1749b08b8e4ae360aa885458e9b45acbd017a39
Change-Id: If2699d1886a3da5f137b9c3b96d92f36fab69d79
@jrbyrnes jrbyrnes merged commit 1636f4a into llvm:main Apr 24, 2025
6 of 10 checks passed
@jrbyrnes
Copy link
Contributor Author

Thanks all

IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This type of decomposition is used in multiple places already. Adding it
to `CmpInstAnalysis` reduces code duplication.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This type of decomposition is used in multiple places already. Adding it
to `CmpInstAnalysis` reduces code duplication.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This type of decomposition is used in multiple places already. Adding it
to `CmpInstAnalysis` reduces code duplication.
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
This type of decomposition is used in multiple places already. Adding it
to `CmpInstAnalysis` reduces code duplication.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants