Skip to content

Replace uses of ConstantExpr::getCompare. #91558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 9, 2024

Conversation

efriedma-quic
Copy link
Collaborator

@efriedma-quic efriedma-quic commented May 9, 2024

Use ICmpInst::compare() where possible, ConstantFoldCompareInstOperands in other places. This only changes places where the either the fold is guaranteed to succeed, or the code doesn't use the resulting compare if we fail to fold.

Use ICmpInst::compare() where possible, ConstantFoldCompareInstruction
in other places.  This only changes places where the either the fold is
guaranteed to succeed, or the code doesn't use the resulting compare if
we fail to fold.
@efriedma-quic efriedma-quic requested a review from nikic as a code owner May 9, 2024 04:58
@llvmbot llvmbot added backend:AMDGPU backend:X86 llvm:ir llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels May 9, 2024
@llvmbot
Copy link
Member

llvmbot commented May 9, 2024

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-x86

Author: Eli Friedman (efriedma-quic)

Changes

Use ICmpInst::compare() where possible, ConstantFoldCompareInstruction in other places. This only changes places where the either the fold is guaranteed to succeed, or the code doesn't use the resulting compare if we fail to fold.


Full diff: https://github.com/llvm/llvm-project/pull/91558.diff

14 Files Affected:

  • (modified) llvm/lib/Analysis/BranchProbabilityInfo.cpp (+3-2)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+4-4)
  • (modified) llvm/lib/Analysis/InlineCost.cpp (+5-7)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+1-3)
  • (modified) llvm/lib/IR/Constants.cpp (+2-2)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (+3-2)
  • (modified) llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp (+3-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+2-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+6-9)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+6-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+6-3)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+4-2)
  • (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+4-4)
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 36a2df6459132..0ef9ff836be29 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
@@ -630,8 +631,8 @@ computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
       if (!CmpLHSConst)
         continue;
       // Now constant-evaluate the compare
-      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
-                                                  CmpLHSConst, CmpConst, true);
+      Constant *Result = ConstantFoldCompareInstruction(CI->getPredicate(),
+                                                        CmpLHSConst, CmpConst);
       // If the result means we don't branch to the block then that block is
       // unlikely.
       if (Result &&
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48a..046a769453808 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1268,10 +1268,10 @@ Constant *llvm::ConstantFoldCompareInstOperands(
       Value *Stripped1 =
           Ops1->stripAndAccumulateInBoundsConstantOffsets(DL, Offset1);
       if (Stripped0 == Stripped1)
-        return ConstantExpr::getCompare(
-            ICmpInst::getSignedPredicate(Predicate),
-            ConstantInt::get(CE0->getContext(), Offset0),
-            ConstantInt::get(CE0->getContext(), Offset1));
+        return ConstantInt::getBool(
+            Ops0->getContext(),
+            ICmpInst::compare(Offset0, Offset1,
+                              ICmpInst::getSignedPredicate(Predicate)));
     }
   } else if (isa<ConstantExpr>(Ops1)) {
     // If RHS is a constant expression, but the left side isn't, swap the
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index c75460f44c1d9..a531064e304d0 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -2046,13 +2046,11 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
     if (RHSBase && LHSBase == RHSBase) {
       // We have common bases, fold the icmp to a constant based on the
       // offsets.
-      Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset);
-      Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset);
-      if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) {
-        SimplifiedValues[&I] = C;
-        ++NumConstantPtrCmps;
-        return true;
-      }
+      SimplifiedValues[&I] = ConstantInt::getBool(
+          I.getType(),
+          ICmpInst::compare(LHSOffset, RHSOffset, I.getPredicate()));
+      ++NumConstantPtrCmps;
+      return true;
     }
   }
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 93f885c5d5ad8..7dc5aa084f3c3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10615,9 +10615,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
-      if (ConstantExpr::getICmp(Pred,
-                                LHSC->getValue(),
-                                RHSC->getValue())->isNullValue())
+      if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
         return TrivialCase(false);
       return TrivialCase(true);
     }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5268eccf70144..db442c54125a7 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -315,8 +315,8 @@ bool Constant::isElementWiseEqual(Value *Y) const {
   Type *IntTy = VectorType::getInteger(VTy);
   Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
   Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
-  Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
-  return isa<PoisonValue>(CmpEq) || match(CmpEq, m_One());
+  Constant *CmpEq = ConstantFoldCompareInstruction(ICmpInst::ICMP_EQ, C0, C1);
+  return CmpEq && (isa<PoisonValue>(CmpEq) || match(CmpEq, m_One()));
 }
 
 static bool
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 5b7fa13f2e835..5cb2de20f0957 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -854,8 +854,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
 
     if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
       if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
-        Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
-        if (CCmp->isNullValue()) {
+        Constant *CCmp = ConstantFoldCompareInstruction(
+            (ICmpInst::Predicate)CCVal, CSrc0, CSrc1);
+        if (CCmp && CCmp->isNullValue()) {
           return IC.replaceInstUsesWith(
               II, IC.Builder.CreateSExt(CCmp, II.getType()));
         }
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index e46fc034cc269..4ecc29317d6b2 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -29,8 +29,9 @@ using namespace llvm;
 static Constant *getNegativeIsTrueBoolVec(Constant *V) {
   VectorType *IntTy = VectorType::getInteger(cast<VectorType>(V->getType()));
   V = ConstantExpr::getBitCast(V, IntTy);
-  V = ConstantExpr::getICmp(CmpInst::ICMP_SGT, Constant::getNullValue(IntTy),
-                            V);
+  V = ConstantFoldCompareInstruction(CmpInst::ICMP_SGT,
+                                     Constant::getNullValue(IntTy), V);
+  assert(V && "Vector must be foldable");
   return V;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ed9a89b14efcc..fe7716f26a5d7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2504,8 +2504,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
           match(C1, m_Power2())) {
         Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
         Constant *Cmp =
-            ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2);
-        if (Cmp->isZeroValue()) {
+            ConstantFoldCompareInstruction(ICmpInst::ICMP_ULT, Log2C3, C2);
+        if (Cmp && Cmp->isZeroValue()) {
           // iff C1,C3 is pow2 and Log2(C3) >= C2:
           // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
           Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d7433ad3599f9..4255864331ecd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1982,7 +1982,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (ModuloC != ShAmtC)
         return replaceOperand(*II, 2, ModuloC);
 
-      assert(match(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC),
+      assert(match(ConstantFoldCompareInstruction(ICmpInst::ICMP_UGT, WidthC,
+                                                  ShAmtC),
                    m_One()) &&
              "Shift amount expected to be modulo bitwidth");
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7092fb5e509bb..e1a3194a1beb7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3176,15 +3176,12 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
                               C3GreaterThan)) {
     assert(C1LessThan && C2Equal && C3GreaterThan);
 
-    bool TrueWhenLessThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C)
-            ->isAllOnesValue();
-    bool TrueWhenEqual =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C)
-            ->isAllOnesValue();
-    bool TrueWhenGreaterThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C)
-            ->isAllOnesValue();
+    bool TrueWhenLessThan = ICmpInst::compare(
+        C1LessThan->getValue(), C->getValue(), Cmp.getPredicate());
+    bool TrueWhenEqual = ICmpInst::compare(C2Equal->getValue(), C->getValue(),
+                                           Cmp.getPredicate());
+    bool TrueWhenGreaterThan = ICmpInst::compare(
+        C3GreaterThan->getValue(), C->getValue(), Cmp.getPredicate());
 
     // This generates the new instruction that will replace the original Cmp
     // Instruction. Instead of enumerating the various combinations when
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8818369e79452..29462e2214242 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1493,14 +1493,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
     std::swap(ThresholdLowIncl, ThresholdHighExcl);
 
   // The fold has a precondition 1: C2 s>= ThresholdLow
-  auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
-                                         ThresholdLowIncl);
-  if (!match(Precond1, m_One()))
+  auto *Precond1 = ConstantFoldCompareInstruction(ICmpInst::Predicate::ICMP_SGE,
+                                                  C2, ThresholdLowIncl);
+  if (!Precond1 || !match(Precond1, m_One()))
     return nullptr;
   // The fold has a precondition 2: C2 s<= ThresholdHigh
-  auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
-                                         ThresholdHighExcl);
-  if (!match(Precond2, m_One()))
+  auto *Precond2 = ConstantFoldCompareInstruction(ICmpInst::Predicate::ICMP_SLE,
+                                                  C2, ThresholdHighExcl);
+  if (!Precond2 || !match(Precond2, m_One()))
     return nullptr;
 
   // If we are matching from a truncated input, we need to sext the
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b6f8b24f43b8c..affa2fdcbb897 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -808,9 +808,12 @@ Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
   Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
   // Need extra check for icmp. Note if this check is true, it generally means
   // the icmp will simplify to true/false.
-  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
-      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
-    return nullptr;
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality()) {
+    Constant *Cmp =
+        ConstantFoldCompareInstruction(ICmpInst::ICMP_UGT, C, BitWidthC);
+    if (Cmp && !Cmp->isZeroValue())
+      return nullptr;
+  }
 
   // Check we can invert `(not x)` for free.
   bool Consumes = false;
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 08d82fa66da30..ac90ad8c08e61 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -868,7 +869,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
 
       for (const auto &LHSVal : LHSVals) {
         Constant *V = LHSVal.first;
-        Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
+        Constant *Folded = ConstantFoldCompareInstruction(Pred, V, CmpConst);
         if (Constant *KC = getKnownConstant(Folded, WantInteger))
           Result.emplace_back(KC, LHSVal.second);
       }
@@ -1538,7 +1539,8 @@ Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB,
       Constant *Op1 =
           evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1));
       if (Op0 && Op1) {
-        return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1);
+        return ConstantFoldCompareInstruction(CondCmp->getPredicate(), Op0,
+                                              Op1);
       }
     }
     return nullptr;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 5a44a11ecfd2c..4efd39c930c99 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6581,16 +6581,16 @@ static void reuseTableCompare(
   Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType());
 
   // Check if the compare with the default value is constant true or false.
-  Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                 DefaultValue, CmpOp1, true);
+  Constant *DefaultConst = ConstantFoldCompareInstruction(
+      CmpInst->getPredicate(), DefaultValue, CmpOp1);
   if (DefaultConst != TrueConst && DefaultConst != FalseConst)
     return;
 
   // Check if the compare with the case values is distinct from the default
   // compare result.
   for (auto ValuePair : Values) {
-    Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                ValuePair.second, CmpOp1, true);
+    Constant *CaseConst = ConstantFoldCompareInstruction(
+        CmpInst->getPredicate(), ValuePair.second, CmpOp1);
     if (!CaseConst || CaseConst == DefaultConst ||
         (CaseConst != TrueConst && CaseConst != FalseConst))
       return;

@llvmbot
Copy link
Member

llvmbot commented May 9, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Eli Friedman (efriedma-quic)

Changes

Use ICmpInst::compare() where possible, ConstantFoldCompareInstruction in other places. This only changes places where the either the fold is guaranteed to succeed, or the code doesn't use the resulting compare if we fail to fold.


Full diff: https://github.com/llvm/llvm-project/pull/91558.diff

14 Files Affected:

  • (modified) llvm/lib/Analysis/BranchProbabilityInfo.cpp (+3-2)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+4-4)
  • (modified) llvm/lib/Analysis/InlineCost.cpp (+5-7)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+1-3)
  • (modified) llvm/lib/IR/Constants.cpp (+2-2)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (+3-2)
  • (modified) llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp (+3-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+2-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+6-9)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+6-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+6-3)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+4-2)
  • (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+4-4)
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 36a2df6459132..0ef9ff836be29 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
@@ -630,8 +631,8 @@ computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
       if (!CmpLHSConst)
         continue;
       // Now constant-evaluate the compare
-      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
-                                                  CmpLHSConst, CmpConst, true);
+      Constant *Result = ConstantFoldCompareInstruction(CI->getPredicate(),
+                                                        CmpLHSConst, CmpConst);
       // If the result means we don't branch to the block then that block is
       // unlikely.
       if (Result &&
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48a..046a769453808 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1268,10 +1268,10 @@ Constant *llvm::ConstantFoldCompareInstOperands(
       Value *Stripped1 =
           Ops1->stripAndAccumulateInBoundsConstantOffsets(DL, Offset1);
       if (Stripped0 == Stripped1)
-        return ConstantExpr::getCompare(
-            ICmpInst::getSignedPredicate(Predicate),
-            ConstantInt::get(CE0->getContext(), Offset0),
-            ConstantInt::get(CE0->getContext(), Offset1));
+        return ConstantInt::getBool(
+            Ops0->getContext(),
+            ICmpInst::compare(Offset0, Offset1,
+                              ICmpInst::getSignedPredicate(Predicate)));
     }
   } else if (isa<ConstantExpr>(Ops1)) {
     // If RHS is a constant expression, but the left side isn't, swap the
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index c75460f44c1d9..a531064e304d0 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -2046,13 +2046,11 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
     if (RHSBase && LHSBase == RHSBase) {
       // We have common bases, fold the icmp to a constant based on the
       // offsets.
-      Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset);
-      Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset);
-      if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) {
-        SimplifiedValues[&I] = C;
-        ++NumConstantPtrCmps;
-        return true;
-      }
+      SimplifiedValues[&I] = ConstantInt::getBool(
+          I.getType(),
+          ICmpInst::compare(LHSOffset, RHSOffset, I.getPredicate()));
+      ++NumConstantPtrCmps;
+      return true;
     }
   }
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 93f885c5d5ad8..7dc5aa084f3c3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10615,9 +10615,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
-      if (ConstantExpr::getICmp(Pred,
-                                LHSC->getValue(),
-                                RHSC->getValue())->isNullValue())
+      if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
         return TrivialCase(false);
       return TrivialCase(true);
     }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5268eccf70144..db442c54125a7 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -315,8 +315,8 @@ bool Constant::isElementWiseEqual(Value *Y) const {
   Type *IntTy = VectorType::getInteger(VTy);
   Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
   Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
-  Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
-  return isa<PoisonValue>(CmpEq) || match(CmpEq, m_One());
+  Constant *CmpEq = ConstantFoldCompareInstruction(ICmpInst::ICMP_EQ, C0, C1);
+  return CmpEq && (isa<PoisonValue>(CmpEq) || match(CmpEq, m_One()));
 }
 
 static bool
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 5b7fa13f2e835..5cb2de20f0957 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -854,8 +854,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
 
     if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
       if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
-        Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
-        if (CCmp->isNullValue()) {
+        Constant *CCmp = ConstantFoldCompareInstruction(
+            (ICmpInst::Predicate)CCVal, CSrc0, CSrc1);
+        if (CCmp && CCmp->isNullValue()) {
           return IC.replaceInstUsesWith(
               II, IC.Builder.CreateSExt(CCmp, II.getType()));
         }
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index e46fc034cc269..4ecc29317d6b2 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -29,8 +29,9 @@ using namespace llvm;
 static Constant *getNegativeIsTrueBoolVec(Constant *V) {
   VectorType *IntTy = VectorType::getInteger(cast<VectorType>(V->getType()));
   V = ConstantExpr::getBitCast(V, IntTy);
-  V = ConstantExpr::getICmp(CmpInst::ICMP_SGT, Constant::getNullValue(IntTy),
-                            V);
+  V = ConstantFoldCompareInstruction(CmpInst::ICMP_SGT,
+                                     Constant::getNullValue(IntTy), V);
+  assert(V && "Vector must be foldable");
   return V;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ed9a89b14efcc..fe7716f26a5d7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2504,8 +2504,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
           match(C1, m_Power2())) {
         Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
         Constant *Cmp =
-            ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2);
-        if (Cmp->isZeroValue()) {
+            ConstantFoldCompareInstruction(ICmpInst::ICMP_ULT, Log2C3, C2);
+        if (Cmp && Cmp->isZeroValue()) {
           // iff C1,C3 is pow2 and Log2(C3) >= C2:
           // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
           Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d7433ad3599f9..4255864331ecd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1982,7 +1982,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (ModuloC != ShAmtC)
         return replaceOperand(*II, 2, ModuloC);
 
-      assert(match(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC),
+      assert(match(ConstantFoldCompareInstruction(ICmpInst::ICMP_UGT, WidthC,
+                                                  ShAmtC),
                    m_One()) &&
              "Shift amount expected to be modulo bitwidth");
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7092fb5e509bb..e1a3194a1beb7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3176,15 +3176,12 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
                               C3GreaterThan)) {
     assert(C1LessThan && C2Equal && C3GreaterThan);
 
-    bool TrueWhenLessThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C)
-            ->isAllOnesValue();
-    bool TrueWhenEqual =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C)
-            ->isAllOnesValue();
-    bool TrueWhenGreaterThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C)
-            ->isAllOnesValue();
+    bool TrueWhenLessThan = ICmpInst::compare(
+        C1LessThan->getValue(), C->getValue(), Cmp.getPredicate());
+    bool TrueWhenEqual = ICmpInst::compare(C2Equal->getValue(), C->getValue(),
+                                           Cmp.getPredicate());
+    bool TrueWhenGreaterThan = ICmpInst::compare(
+        C3GreaterThan->getValue(), C->getValue(), Cmp.getPredicate());
 
     // This generates the new instruction that will replace the original Cmp
     // Instruction. Instead of enumerating the various combinations when
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8818369e79452..29462e2214242 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1493,14 +1493,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
     std::swap(ThresholdLowIncl, ThresholdHighExcl);
 
   // The fold has a precondition 1: C2 s>= ThresholdLow
-  auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
-                                         ThresholdLowIncl);
-  if (!match(Precond1, m_One()))
+  auto *Precond1 = ConstantFoldCompareInstruction(ICmpInst::Predicate::ICMP_SGE,
+                                                  C2, ThresholdLowIncl);
+  if (!Precond1 || !match(Precond1, m_One()))
     return nullptr;
   // The fold has a precondition 2: C2 s<= ThresholdHigh
-  auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
-                                         ThresholdHighExcl);
-  if (!match(Precond2, m_One()))
+  auto *Precond2 = ConstantFoldCompareInstruction(ICmpInst::Predicate::ICMP_SLE,
+                                                  C2, ThresholdHighExcl);
+  if (!Precond2 || !match(Precond2, m_One()))
     return nullptr;
 
   // If we are matching from a truncated input, we need to sext the
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b6f8b24f43b8c..affa2fdcbb897 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -808,9 +808,12 @@ Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
   Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
   // Need extra check for icmp. Note if this check is true, it generally means
   // the icmp will simplify to true/false.
-  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
-      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
-    return nullptr;
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality()) {
+    Constant *Cmp =
+        ConstantFoldCompareInstruction(ICmpInst::ICMP_UGT, C, BitWidthC);
+    if (Cmp && !Cmp->isZeroValue())
+      return nullptr;
+  }
 
   // Check we can invert `(not x)` for free.
   bool Consumes = false;
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 08d82fa66da30..ac90ad8c08e61 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -868,7 +869,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
 
       for (const auto &LHSVal : LHSVals) {
         Constant *V = LHSVal.first;
-        Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
+        Constant *Folded = ConstantFoldCompareInstruction(Pred, V, CmpConst);
         if (Constant *KC = getKnownConstant(Folded, WantInteger))
           Result.emplace_back(KC, LHSVal.second);
       }
@@ -1538,7 +1539,8 @@ Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB,
       Constant *Op1 =
           evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1));
       if (Op0 && Op1) {
-        return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1);
+        return ConstantFoldCompareInstruction(CondCmp->getPredicate(), Op0,
+                                              Op1);
       }
     }
     return nullptr;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 5a44a11ecfd2c..4efd39c930c99 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6581,16 +6581,16 @@ static void reuseTableCompare(
   Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType());
 
   // Check if the compare with the default value is constant true or false.
-  Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                 DefaultValue, CmpOp1, true);
+  Constant *DefaultConst = ConstantFoldCompareInstruction(
+      CmpInst->getPredicate(), DefaultValue, CmpOp1);
   if (DefaultConst != TrueConst && DefaultConst != FalseConst)
     return;
 
   // Check if the compare with the case values is distinct from the default
   // compare result.
   for (auto ValuePair : Values) {
-    Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                ValuePair.second, CmpOp1, true);
+    Constant *CaseConst = ConstantFoldCompareInstruction(
+        CmpInst->getPredicate(), ValuePair.second, CmpOp1);
     if (!CaseConst || CaseConst == DefaultConst ||
         (CaseConst != TrueConst && CaseConst != FalseConst))
       return;

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically fine, but please use ConstantFoldCompareInstOperands instead of ConstantFoldCompareInstruction. The latter should essentially be considered a private API.

(The only reason it isn't is that IRBuilder ConstantFolder needs it, which is defined in a header.)

@efriedma-quic efriedma-quic changed the title [NFC] Replace uses of ConstantExpr::getCompare. Replace uses of ConstantExpr::getCompare. May 9, 2024
@efriedma-quic
Copy link
Collaborator Author

Made suggested change. I didn't change the use in llvm/lib/IR/Constants.cpp; I think that would introduce a circular dependency.

Note that this is not NFC anymore; folding using the datalayout actually unblocks transforms in a few places.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@efriedma-quic efriedma-quic merged commit f893dcc into llvm:main May 9, 2024
@efriedma-quic efriedma-quic deleted the constantexpr-cmp branch May 9, 2024 23:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU backend:X86 llvm:analysis Includes value tracking, cost tables and constant folding llvm:ir llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants