[CmpInstAnalysis] Return decomposed bit test as struct (NFC) #109819

Merged
nikic merged 3 commits into llvm:main from decompose-bit-struct
Sep 25, 2024

Contributor

@nikic nikic commented Sep 24, 2024

decomposeBitTestICmp() currently returns the result via two out parameters plus an in-place modification of Pred. This changes it to return an optional struct instead.

The motivation here is twofold. First, I'd like to extend this code to handle cases where the comparison is against a value other than zero, which would mean yet another out parameter. Second, while doing that I was badly bitten by the in-place modification, so I'd like to get rid of it.
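
To illustrate the shape of this change without depending on LLVM headers, here is a minimal standalone sketch on plain 32-bit integers. `Pred`, `BitTest`, `decomposeOld`, and `decomposeNew` are hypothetical names used only for illustration, not LLVM APIs, and only the signed compare-with-zero cases are modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for CmpInst::Predicate; not the LLVM enum.
enum class Pred { EQ, NE, SLT, SGE };

// Old style: results come back through out parameters, and P is
// overwritten in place, so the caller loses the original predicate.
inline bool decomposeOld(int32_t C, Pred &P, uint32_t &Mask) {
  if (C != 0)
    return false;
  switch (P) {
  case Pred::SLT: // X < 0  <=>  (X & SignMask) != 0
    Mask = 0x80000000u;
    P = Pred::NE;
    return true;
  case Pred::SGE: // X >= 0  <=>  (X & SignMask) == 0
    Mask = 0x80000000u;
    P = Pred::EQ;
    return true;
  default:
    return false;
  }
}

// New style: everything the caller needs is bundled in one struct.
struct BitTest {
  Pred P;        // EQ or NE
  uint32_t Mask;
};

// The predicate is taken by value, so the caller's copy is untouched.
inline std::optional<BitTest> decomposeNew(int32_t C, Pred P) {
  if (C != 0)
    return std::nullopt;
  switch (P) {
  case Pred::SLT:
    return BitTest{Pred::NE, 0x80000000u};
  case Pred::SGE:
    return BitTest{Pred::EQ, 0x80000000u};
  default:
    return std::nullopt;
  }
}
```

With the second form, adding another result field later only touches the struct, and a caller that still needs the original predicate after the call keeps it.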

@nikic nikic requested review from dtcxzyw and goldsteinn September 24, 2024 15:51
@llvmbot llvmbot added the llvm:analysis (Includes value tracking, cost tables and constant folding) and llvm:transforms labels Sep 24, 2024
Member

llvmbot commented Sep 24, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)


Full diff: https://github.com/llvm/llvm-project/pull/109819.diff

7 Files Affected:

  • (modified) llvm/include/llvm/Analysis/CmpInstAnalysis.h (+13-6)
  • (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+35-32)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+4-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+13-7)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+4-5)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+12-8)
  • (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+10-4)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index 1d07a0c22887bb..406dacd930605e 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_ANALYSIS_CMPINSTANALYSIS_H
 #define LLVM_ANALYSIS_CMPINSTANALYSIS_H
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/IR/InstrTypes.h"
 
 namespace llvm {
@@ -91,12 +92,18 @@ namespace llvm {
   Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
                                CmpInst::Predicate &Pred);
 
-  /// Decompose an icmp into the form ((X & Mask) pred 0) if possible. The
-  /// returned predicate is either == or !=. Returns false if decomposition
-  /// fails.
-  bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
-                            Value *&X, APInt &Mask,
-                            bool LookThroughTrunc = true);
+  /// Represents the operation icmp (X & Mask) pred 0, where pred can only be
+  /// eq or ne.
+  struct DecomposedBitTest {
+    Value *X;
+    CmpInst::Predicate Pred;
+    APInt Mask;
+  };
+
+  /// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
+  std::optional<DecomposedBitTest>
+  decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
+                       bool LookThroughTrunc = true);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index a1fa7857764d98..36d7aa510545af 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -73,81 +73,84 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
   return nullptr;
 }
 
-bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
-                                CmpInst::Predicate &Pred,
-                                Value *&X, APInt &Mask, bool LookThruTrunc) {
+std::optional<DecomposedBitTest>
+llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
+                           bool LookThruTrunc) {
   using namespace PatternMatch;
 
   const APInt *C;
   if (!match(RHS, m_APIntAllowPoison(C)))
-    return false;
+    return std::nullopt;
 
+  DecomposedBitTest Result;
   switch (Pred) {
   default:
-    return false;
+    return std::nullopt;
   case ICmpInst::ICMP_SLT:
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (!C->isZero())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   case ICmpInst::ICMP_SLE:
     // X <= -1 is equivalent to (X & SignMask) != 0.
     if (!C->isAllOnes())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   case ICmpInst::ICMP_SGT:
     // X > -1 is equivalent to (X & SignMask) == 0.
     if (!C->isAllOnes())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_SGE:
     // X >= 0 is equivalent to (X & SignMask) == 0.
     if (!C->isZero())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_ULT:
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (!C->isPowerOf2())
-      return false;
-    Mask = -*C;
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = -*C;
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_ULE:
     // X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0.
     if (!(*C + 1).isPowerOf2())
-      return false;
-    Mask = ~*C;
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = ~*C;
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_UGT:
     // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0.
     if (!(*C + 1).isPowerOf2())
-      return false;
-    Mask = ~*C;
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = ~*C;
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   case ICmpInst::ICMP_UGE:
     // X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0.
     if (!C->isPowerOf2())
-      return false;
-    Mask = -*C;
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = -*C;
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   }
 
+  Value *X;
   if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
-    Mask = Mask.zext(X->getType()->getScalarSizeInBits());
+    Result.X = X;
+    Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
   } else {
-    X = LHS;
+    Result.X = LHS;
   }
 
-  return true;
+  return Result;
 }
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 32a9f1ab34fb3f..90f05d43a2b147 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4624,13 +4624,11 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
 static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
                                            ICmpInst::Predicate Pred,
                                            Value *TrueVal, Value *FalseVal) {
-  Value *X;
-  APInt Mask;
-  if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask))
-    return nullptr;
+  if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred))
+    return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask,
+                                 Res->Pred == ICmpInst::ICMP_EQ);
 
-  return simplifySelectBitTest(TrueVal, FalseVal, X, &Mask,
-                               Pred == ICmpInst::ICMP_EQ);
+  return nullptr;
 }
 
 /// Try to simplify a select instruction when its condition operand is an
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 80d3adedfc89f3..2c2d24d392a938 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -181,11 +181,13 @@ static unsigned conjugateICmpMask(unsigned Mask) {
 // Adapts the external decomposeBitTestICmp for local use.
 static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
                                  Value *&X, Value *&Y, Value *&Z) {
-  APInt Mask;
-  if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask))
+  auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
+  if (!Res)
     return false;
 
-  Y = ConstantInt::get(X->getType(), Mask);
+  Pred = Res->Pred;
+  X = Res->X;
+  Y = ConstantInt::get(X->getType(), Res->Mask);
   Z = ConstantInt::get(X->getType(), 0);
   return true;
 }
@@ -870,11 +872,15 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpInst::Predicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, X, UnsetBitsMask,
-                                   /*LookThroughTrunc=*/false) &&
-        Pred == ICmpInst::ICMP_EQ)
+    if (auto Res =
+            llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+                                       Pred, /*LookThroughTrunc=*/false);
+        Res && Res->Pred == ICmpInst::ICMP_EQ) {
+      X = Res->X;
+      UnsetBitsMask = Res->Mask;
       return true;
+    }
+
     // Is it  icmp eq (X & Mask), 0  already?
     const APInt *Mask;
     if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 698abbb34c18c3..b1215bb4d83b0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5905,11 +5905,10 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
   // This matches patterns corresponding to tests of the signbit as well as:
   // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
   // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
-  APInt Mask;
-  if (decomposeBitTestICmp(Op0, Op1, Pred, X, Mask, true /* WithTrunc */)) {
-    Value *And = Builder.CreateAnd(X, Mask);
-    Constant *Zero = ConstantInt::getNullValue(X->getType());
-    return new ICmpInst(Pred, And, Zero);
+  if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
+    Value *And = Builder.CreateAnd(Res->X, Res->Mask);
+    Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
+    return new ICmpInst(Res->Pred, And, Zero);
   }
 
   unsigned SrcBits = X->getType()->getScalarSizeInBits();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7476db9ee38f45..3dbe95897d6356 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -145,12 +145,15 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
       return nullptr;
 
     AndMask = *AndRHS;
-  } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1),
-                                  Pred, V, AndMask)) {
-    assert(ICmpInst::isEquality(Pred) && "Not equality test?");
-    if (!AndMask.isPowerOf2())
+  } else if (auto Res = decomposeBitTestICmp(Cmp->getOperand(0),
+                                             Cmp->getOperand(1), Pred)) {
+    assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
+    if (!Res->Mask.isPowerOf2())
       return nullptr;
 
+    V = Res->X;
+    AndMask = Res->Mask;
+    Pred = Res->Pred;
     CreateAnd = true;
   } else {
     return nullptr;
@@ -747,12 +750,13 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
 
     C1Log = C1->logBase2();
   } else {
-    APInt C1;
-    if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) ||
-        !C1.isPowerOf2())
+    auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    if (!Res || !Res->Mask.isPowerOf2())
       return nullptr;
 
-    C1Log = C1.logBase2();
+    CmpLHS = Res->X;
+    Pred = Res->Pred;
+    C1Log = Res->Mask.logBase2();
     NeedAnd = true;
   }
 
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 578d087e470e1e..e3c3984ccb5156 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2465,10 +2465,16 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
   };
   auto MatchDecomposableConstantBitMask = [&]() {
     APInt Mask;
-    return llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CurrX, Mask) &&
-           ICmpInst::isEquality(Pred) && Mask.isPowerOf2() &&
-           (BitMask = ConstantInt::get(CurrX->getType(), Mask)) &&
-           (BitPos = ConstantInt::get(CurrX->getType(), Mask.logBase2()));
+    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    if (Res && Res->Mask.isPowerOf2()) {
+      assert(ICmpInst::isEquality(Res->Pred));
+      Pred = Res->Pred;
+      CurrX = Res->X;
+      BitMask = ConstantInt::get(CurrX->getType(), Res->Mask);
+      BitPos = ConstantInt::get(CurrX->getType(), Res->Mask.logBase2());
+      return true;
+    }
+    return false;
   };
 
   if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
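
The equivalences stated in the comments of `decomposeBitTestICmp()` can be checked by brute force on a small bit width. The following is a sketch under stated assumptions: it models `APInt` with `uint8_t`, ignores the trunc look-through, and uses hypothetical names (`Pred`, `BitTest8`, `decompose8`, `evalOriginal`, `evalBitTest`, `allDecompositionsAgree`) that are not part of LLVM.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the integer predicates; not the LLVM enum.
enum class Pred { EQ, NE, SLT, SLE, SGT, SGE, ULT, ULE, UGT, UGE };

// Models "(X & Mask) P 0" on 8-bit values, where P is EQ or NE.
struct BitTest8 {
  Pred P;
  uint8_t Mask;
};

inline bool isPow2(uint8_t V) { return V != 0 && (V & (V - 1)) == 0; }

// Mirrors the case analysis in the diff, with uint8_t in place of APInt.
inline std::optional<BitTest8> decompose8(Pred P, uint8_t C) {
  const uint8_t SignMask = 0x80;
  const uint8_t NegC = static_cast<uint8_t>(0u - C); // two's complement -C
  const uint8_t NotC = static_cast<uint8_t>(~C);
  const bool CPlus1Pow2 = isPow2(static_cast<uint8_t>(C + 1));
  switch (P) {
  case Pred::SLT: // X < 0
    if (C != 0) return std::nullopt;
    return BitTest8{Pred::NE, SignMask};
  case Pred::SLE: // X <= -1
    if (C != 0xFF) return std::nullopt;
    return BitTest8{Pred::NE, SignMask};
  case Pred::SGT: // X > -1
    if (C != 0xFF) return std::nullopt;
    return BitTest8{Pred::EQ, SignMask};
  case Pred::SGE: // X >= 0
    if (C != 0) return std::nullopt;
    return BitTest8{Pred::EQ, SignMask};
  case Pred::ULT: // X <u 2^n
    if (!isPow2(C)) return std::nullopt;
    return BitTest8{Pred::EQ, NegC};
  case Pred::ULE: // X <=u 2^n-1
    if (!CPlus1Pow2) return std::nullopt;
    return BitTest8{Pred::EQ, NotC};
  case Pred::UGT: // X >u 2^n-1
    if (!CPlus1Pow2) return std::nullopt;
    return BitTest8{Pred::NE, NotC};
  case Pred::UGE: // X >=u 2^n
    if (!isPow2(C)) return std::nullopt;
    return BitTest8{Pred::NE, NegC};
  default:
    return std::nullopt;
  }
}

// Evaluates the original comparison "X P C" directly.
inline bool evalOriginal(Pred P, uint8_t X, uint8_t C) {
  const int8_t SX = static_cast<int8_t>(X), SC = static_cast<int8_t>(C);
  switch (P) {
  case Pred::EQ:  return X == C;
  case Pred::NE:  return X != C;
  case Pred::SLT: return SX < SC;
  case Pred::SLE: return SX <= SC;
  case Pred::SGT: return SX > SC;
  case Pred::SGE: return SX >= SC;
  case Pred::ULT: return X < C;
  case Pred::ULE: return X <= C;
  case Pred::UGT: return X > C;
  case Pred::UGE: return X >= C;
  }
  return false;
}

// Evaluates the decomposed form "(X & Mask) P 0".
inline bool evalBitTest(const BitTest8 &T, uint8_t X) {
  const bool IsZero = (X & T.Mask) == 0;
  return T.P == Pred::EQ ? IsZero : !IsZero;
}

// Checks every (predicate, C, X) triple for which decomposition succeeds.
inline bool allDecompositionsAgree() {
  const Pred Preds[] = {Pred::SLT, Pred::SLE, Pred::SGT, Pred::SGE,
                        Pred::ULT, Pred::ULE, Pred::UGT, Pred::UGE};
  for (Pred P : Preds)
    for (int C = 0; C < 256; ++C) {
      const auto R = decompose8(P, static_cast<uint8_t>(C));
      if (!R) continue;
      if (R->P != Pred::EQ && R->P != Pred::NE) return false;
      for (int X = 0; X < 256; ++X)
        if (evalOriginal(P, static_cast<uint8_t>(X), static_cast<uint8_t>(C)) !=
            evalBitTest(*R, static_cast<uint8_t>(X)))
          return false;
    }
  return true;
}
```

Calling `allDecompositionsAgree()` walks all 8-bit constants and inputs. It also confirms that, in this model, every successful decomposition yields an equality predicate.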

-           (BitPos = ConstantInt::get(CurrX->getType(), Mask.logBase2()));
+    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    if (Res && Res->Mask.isPowerOf2()) {
+      assert(ICmpInst::isEquality(Res->Pred));
Contributor

Why did this become an assert?

Contributor Author

decomposeBitTestICmp() is required to always return an equality predicate.
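
A minimal sketch of the distinction being drawn here, assuming the decomposition contract guarantees an equality predicate: whether the mask is a single bit depends on the input and stays a runtime check, while the predicate check only documents a guarantee and can therefore be an assert. `Pred`, `BitTest`, `isEquality`, and `singleBitPosition` are hypothetical and not part of LLVM.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for CmpInst::Predicate; not the LLVM enum.
enum class Pred { EQ, NE, SLT };

// Hypothetical simplified result type; not LLVM's DecomposedBitTest.
struct BitTest {
  Pred P;
  uint32_t Mask;
};

inline bool isEquality(Pred P) { return P == Pred::EQ || P == Pred::NE; }

// Returns the index of the tested bit when exactly one bit is tested.
inline std::optional<unsigned>
singleBitPosition(const std::optional<BitTest> &Res) {
  // Input-dependent: the decomposition may simply not apply.
  if (!Res)
    return std::nullopt;
  // Input-dependent: the mask may cover more than one bit.
  const uint32_t M = Res->Mask;
  if (M == 0 || (M & (M - 1)) != 0)
    return std::nullopt;
  // Contract: a successful decomposition always carries EQ or NE.
  assert(isEquality(Res->P) && "decomposition must yield eq or ne");
  unsigned Pos = 0;
  while ((M >> Pos) != 1u)
    ++Pos;
  return Pos;
}
```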

-        Pred == ICmpInst::ICMP_EQ)
+    if (auto Res =
+            llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+                                       Pred, /*LookThroughTrunc=*/false);
Contributor

IMO, just yank the auto Res = llvm::dec... out of the if.
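
For comparison, the two forms under discussion can be sketched side by side in standalone C++17. This is only an illustration: `BitTest`, `decomposeULT`, `unsetBitsMaskInit`, and `unsetBitsMaskHoisted` are hypothetical names, and the helper models just the `X <u 2^n` case on 32-bit integers.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical simplified result type; not LLVM's DecomposedBitTest.
struct BitTest {
  bool IsEq;     // true for "== 0", false for "!= 0"
  uint32_t Mask;
};

// Hypothetical helper: decomposes "X <u C" into "(X & -C) == 0"
// when C is a power of two.
inline std::optional<BitTest> decomposeULT(uint32_t C) {
  if (C == 0 || (C & (C - 1)) != 0)
    return std::nullopt;
  return BitTest{true, 0u - C};
}

// Declaration inside the if, using a C++17 init-statement.
// Res is scoped to the if statement only.
inline std::optional<uint32_t> unsetBitsMaskInit(uint32_t C) {
  if (auto Res = decomposeULT(C); Res && Res->IsEq)
    return Res->Mask;
  return std::nullopt;
}

// Declaration hoisted out of the if, as suggested in review.
// Res stays visible for the rest of the function.
inline std::optional<uint32_t> unsetBitsMaskHoisted(uint32_t C) {
  auto Res = decomposeULT(C);
  if (Res && Res->IsEq)
    return Res->Mask;
  return std::nullopt;
}
```

Both functions behave identically. The difference is readability and how far the scope of `Res` extends.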

Contributor

@goldsteinn goldsteinn left a comment

LGTM

Member

@dtcxzyw dtcxzyw left a comment

LG

@@ -2465,10 +2465,16 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
   };
   auto MatchDecomposableConstantBitMask = [&]() {
     APInt Mask;
Member

It is unused now.

@nikic nikic merged commit b8d1bae into llvm:main Sep 25, 2024
6 of 8 checks passed
@nikic nikic deleted the decompose-bit-struct branch September 25, 2024 08:14