Skip to content

[ValueTracking] Move getFlippedStrictnessPredicateAndConstant into ValueTracking. NFC. #122064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 8, 2025

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Jan 8, 2025

Needed by #121958.

@dtcxzyw dtcxzyw requested a review from nikic as a code owner January 8, 2025 07:23
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Jan 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

Needed by #121958.


Full diff: https://github.com/llvm/llvm-project/pull/122064.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+7)
  • (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (-6)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+74)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+4-80)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+2-4)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 8aa024a72afc88..b4918c2d1e8a18 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
                                    Instruction *OnPathTo,
                                    DominatorTree *DT);
 
+/// Convert an integer comparison with a constant RHS into an equivalent
+/// form with the strictness flipped predicate. Return the new predicate and
+/// corresponding constant RHS if possible. Otherwise return std::nullopt.
+/// E.g., (icmp sgt X, 0) -> (icmp sle X, 1).
+std::optional<std::pair<CmpPredicate, Constant *>>
+getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C);
+
 /// Specific patterns of select instructions we can match.
 enum SelectPatternFlavor {
   SPF_UNKNOWN = 0,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 71592058e34563..fa6b60cba15aaf 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
     return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
   }
 
-  std::optional<std::pair<
-      CmpPredicate,
-      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
-                                                                       Pred,
-                                                                   Constant *C);
-
   static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
     // a ? b : false and a ? true : b are the canonical form of logical and/or.
     // This includes !a ? b : false and !a ? true : b. Absorbing the not into
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b735..0eb43dd581acc6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
   }
 }
 
+std::optional<std::pair<CmpPredicate, Constant *>>
+llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
+  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
+         "Only for relational integer predicates.");
+  if (isa<UndefValue>(C))
+    return std::nullopt;
+
+  Type *Type = C->getType();
+  bool IsSigned = ICmpInst::isSigned(Pred);
+
+  CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
+  bool WillIncrement =
+      UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
+
+  // Check if the constant operand can be safely incremented/decremented
+  // without overflowing/underflowing.
+  auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
+    return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
+  };
+
+  Constant *SafeReplacementConstant = nullptr;
+  if (auto *CI = dyn_cast<ConstantInt>(C)) {
+    // Bail out if the constant can't be safely incremented/decremented.
+    if (!ConstantIsOk(CI))
+      return std::nullopt;
+  } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
+    unsigned NumElts = FVTy->getNumElements();
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (!Elt)
+        return std::nullopt;
+
+      if (isa<UndefValue>(Elt))
+        continue;
+
+      // Bail out if we can't determine if this constant is min/max or if we
+      // know that this constant is min/max.
+      auto *CI = dyn_cast<ConstantInt>(Elt);
+      if (!CI || !ConstantIsOk(CI))
+        return std::nullopt;
+
+      if (!SafeReplacementConstant)
+        SafeReplacementConstant = CI;
+    }
+  } else if (isa<VectorType>(C->getType())) {
+    // Handle scalable splat
+    Value *SplatC = C->getSplatValue();
+    auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
+    // Bail out if the constant can't be safely incremented/decremented.
+    if (!CI || !ConstantIsOk(CI))
+      return std::nullopt;
+  } else {
+    // ConstantExpr?
+    return std::nullopt;
+  }
+
+  // It may not be safe to change a compare predicate in the presence of
+  // undefined elements, so replace those elements with the first safe constant
+  // that we found.
+  // TODO: in case of poison, it is safe; let's replace undefs only.
+  if (C->containsUndefOrPoisonElement()) {
+    assert(SafeReplacementConstant && "Replacement constant not set");
+    C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
+  }
+
+  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
+
+  // Increment or decrement the constant.
+  Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
+  Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
+
+  return std::make_pair(NewPred, NewC);
+}
+
 static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
                                               FastMathFlags FMF,
                                               Value *CmpLHS, Value *CmpRHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8b23583c510637..c2d659035877ed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
       // icmp ule i64 (shl X, 32), 8589934592 ->
       // icmp ule i32 (trunc X, i32), 2 ->
       // icmp ult i32 (trunc X, i32), 3
-      if (auto FlippedStrictness =
-              InstCombiner::getFlippedStrictnessPredicateAndConstant(
-                  Pred, ConstantInt::get(ShType->getContext(), C))) {
+      if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
+              Pred, ConstantInt::get(ShType->getContext(), C))) {
         CmpPred = FlippedStrictness->first;
         RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
       }
@@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
   if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
     // x sgt C-1  <-->  x sge C  <-->  not(x slt C)
     auto FlippedStrictness =
-        InstCombiner::getFlippedStrictnessPredicateAndConstant(
-            PredB, cast<Constant>(RHS2));
+        getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
     if (!FlippedStrictness)
       return false;
     assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
@@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
   return nullptr;
 }
 
-std::optional<std::pair<CmpPredicate, Constant *>>
-InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
-                                                       Constant *C) {
-  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
-         "Only for relational integer predicates.");
-
-  Type *Type = C->getType();
-  bool IsSigned = ICmpInst::isSigned(Pred);
-
-  CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
-  bool WillIncrement =
-      UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
-
-  // Check if the constant operand can be safely incremented/decremented
-  // without overflowing/underflowing.
-  auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
-    return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
-  };
-
-  Constant *SafeReplacementConstant = nullptr;
-  if (auto *CI = dyn_cast<ConstantInt>(C)) {
-    // Bail out if the constant can't be safely incremented/decremented.
-    if (!ConstantIsOk(CI))
-      return std::nullopt;
-  } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
-    unsigned NumElts = FVTy->getNumElements();
-    for (unsigned i = 0; i != NumElts; ++i) {
-      Constant *Elt = C->getAggregateElement(i);
-      if (!Elt)
-        return std::nullopt;
-
-      if (isa<UndefValue>(Elt))
-        continue;
-
-      // Bail out if we can't determine if this constant is min/max or if we
-      // know that this constant is min/max.
-      auto *CI = dyn_cast<ConstantInt>(Elt);
-      if (!CI || !ConstantIsOk(CI))
-        return std::nullopt;
-
-      if (!SafeReplacementConstant)
-        SafeReplacementConstant = CI;
-    }
-  } else if (isa<VectorType>(C->getType())) {
-    // Handle scalable splat
-    Value *SplatC = C->getSplatValue();
-    auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
-    // Bail out if the constant can't be safely incremented/decremented.
-    if (!CI || !ConstantIsOk(CI))
-      return std::nullopt;
-  } else {
-    // ConstantExpr?
-    return std::nullopt;
-  }
-
-  // It may not be safe to change a compare predicate in the presence of
-  // undefined elements, so replace those elements with the first safe constant
-  // that we found.
-  // TODO: in case of poison, it is safe; let's replace undefs only.
-  if (C->containsUndefOrPoisonElement()) {
-    assert(SafeReplacementConstant && "Replacement constant not set");
-    C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
-  }
-
-  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
-
-  // Increment or decrement the constant.
-  Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
-  Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
-
-  return std::make_pair(NewPred, NewC);
-}
-
 /// If we have an icmp le or icmp ge instruction with a constant operand, turn
 /// it into the appropriate icmp lt or icmp gt instruction. This transform
 /// allows them to be folded in visitICmpInst.
@@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
   if (!Op1C)
     return nullptr;
 
-  auto FlippedStrictness =
-      InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+  auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
   if (!FlippedStrictness)
     return nullptr;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7fd91c72a2fb0e..eca518aa640700 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
     return nullptr;
 
   // Check the constant we'd have with flipped-strictness predicate.
-  auto FlippedStrictness =
-      InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
+  auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
   if (!FlippedStrictness)
     return nullptr;
 
@@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
   Value *RHS;
   SelectPatternFlavor SPF;
   const DataLayout &DL = BOp->getDataLayout();
-  auto Flipped =
-      InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);
+  auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1);
 
   if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) {
     SPF = getSelectPattern(Predicate).Flavor;

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw dtcxzyw merged commit 03e7862 into llvm:main Jan 8, 2025
12 checks passed
@dtcxzyw dtcxzyw deleted the move-flip-strictness branch January 8, 2025 12:02
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jan 8, 2025

LLVM Buildbot has detected a new failure on builder openmp-offload-libc-amdgpu-runtime running on omp-vega20-1 while building llvm at step 6 "test-openmp".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/73/builds/11355

Here is the relevant piece of the build log for the reference
Step 6 (test-openmp) failure: test (failure)
******************** TEST 'libomp :: tasking/issue-94260-2.c' FAILED ********************
Exit Code: -11

Command Output (stdout):
--
# RUN: at line 1
/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/clang -fopenmp   -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/openmp/runtime/test -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src  -fno-omit-frame-pointer -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/openmp/runtime/test/ompt /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/openmp/runtime/test/tasking/issue-94260-2.c -o /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/test/tasking/Output/issue-94260-2.c.tmp -lm -latomic && /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/test/tasking/Output/issue-94260-2.c.tmp
# executed command: /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/clang -fopenmp -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/openmp/runtime/test -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -fno-omit-frame-pointer -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/openmp/runtime/test/ompt /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/openmp/runtime/test/tasking/issue-94260-2.c -o /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/test/tasking/Output/issue-94260-2.c.tmp -lm -latomic
# executed command: /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/test/tasking/Output/issue-94260-2.c.tmp
# note: command had no output on stdout or stderr
# error: command failed with exit status: -11

--

********************


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants