-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[ValueTracking] Move getFlippedStrictnessPredicateAndConstant
into ValueTracking. NFC.
#122064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ValueTracking. NFC.
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis Author: Yingwei Zheng (dtcxzyw) ChangesNeeded by #121958. Full diff: https://github.com/llvm/llvm-project/pull/122064.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 8aa024a72afc88..b4918c2d1e8a18 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
Instruction *OnPathTo,
DominatorTree *DT);
+/// Convert an integer comparison with a constant RHS into an equivalent
+/// form with the strictness flipped predicate. Return the new predicate and
+/// corresponding constant RHS if possible. Otherwise return std::nullopt.
+/// E.g., (icmp sgt X, 0) -> (icmp sle X, 1).
+std::optional<std::pair<CmpPredicate, Constant *>>
+getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C);
+
/// Specific patterns of select instructions we can match.
enum SelectPatternFlavor {
SPF_UNKNOWN = 0,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 71592058e34563..fa6b60cba15aaf 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
}
- std::optional<std::pair<
- CmpPredicate,
- Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
- Pred,
- Constant *C);
-
static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
// a ? b : false and a ? true : b are the canonical form of logical and/or.
// This includes !a ? b : false and !a ? true : b. Absorbing the not into
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b735..0eb43dd581acc6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
}
}
+std::optional<std::pair<CmpPredicate, Constant *>>
+llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
+ assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
+ "Only for relational integer predicates.");
+ if (isa<UndefValue>(C))
+ return std::nullopt;
+
+ Type *Type = C->getType();
+ bool IsSigned = ICmpInst::isSigned(Pred);
+
+ CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
+ bool WillIncrement =
+ UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
+
+ // Check if the constant operand can be safely incremented/decremented
+ // without overflowing/underflowing.
+ auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
+ return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
+ };
+
+ Constant *SafeReplacementConstant = nullptr;
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
+ // Bail out if the constant can't be safely incremented/decremented.
+ if (!ConstantIsOk(CI))
+ return std::nullopt;
+ } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
+ unsigned NumElts = FVTy->getNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *Elt = C->getAggregateElement(i);
+ if (!Elt)
+ return std::nullopt;
+
+ if (isa<UndefValue>(Elt))
+ continue;
+
+ // Bail out if we can't determine if this constant is min/max or if we
+ // know that this constant is min/max.
+ auto *CI = dyn_cast<ConstantInt>(Elt);
+ if (!CI || !ConstantIsOk(CI))
+ return std::nullopt;
+
+ if (!SafeReplacementConstant)
+ SafeReplacementConstant = CI;
+ }
+ } else if (isa<VectorType>(C->getType())) {
+ // Handle scalable splat
+ Value *SplatC = C->getSplatValue();
+ auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
+ // Bail out if the constant can't be safely incremented/decremented.
+ if (!CI || !ConstantIsOk(CI))
+ return std::nullopt;
+ } else {
+ // ConstantExpr?
+ return std::nullopt;
+ }
+
+ // It may not be safe to change a compare predicate in the presence of
+ // undefined elements, so replace those elements with the first safe constant
+ // that we found.
+ // TODO: in case of poison, it is safe; let's replace undefs only.
+ if (C->containsUndefOrPoisonElement()) {
+ assert(SafeReplacementConstant && "Replacement constant not set");
+ C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
+ }
+
+ CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
+
+ // Increment or decrement the constant.
+ Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
+ Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
+
+ return std::make_pair(NewPred, NewC);
+}
+
static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
FastMathFlags FMF,
Value *CmpLHS, Value *CmpRHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8b23583c510637..c2d659035877ed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
// icmp ule i64 (shl X, 32), 8589934592 ->
// icmp ule i32 (trunc X, i32), 2 ->
// icmp ult i32 (trunc X, i32), 3
- if (auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(
- Pred, ConstantInt::get(ShType->getContext(), C))) {
+ if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
+ Pred, ConstantInt::get(ShType->getContext(), C))) {
CmpPred = FlippedStrictness->first;
RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
}
@@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
// x sgt C-1 <--> x sge C <--> not(x slt C)
auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(
- PredB, cast<Constant>(RHS2));
+ getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
if (!FlippedStrictness)
return false;
assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
@@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
return nullptr;
}
-std::optional<std::pair<CmpPredicate, Constant *>>
-InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
- Constant *C) {
- assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
- "Only for relational integer predicates.");
-
- Type *Type = C->getType();
- bool IsSigned = ICmpInst::isSigned(Pred);
-
- CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
- bool WillIncrement =
- UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
-
- // Check if the constant operand can be safely incremented/decremented
- // without overflowing/underflowing.
- auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
- return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
- };
-
- Constant *SafeReplacementConstant = nullptr;
- if (auto *CI = dyn_cast<ConstantInt>(C)) {
- // Bail out if the constant can't be safely incremented/decremented.
- if (!ConstantIsOk(CI))
- return std::nullopt;
- } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
- unsigned NumElts = FVTy->getNumElements();
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = C->getAggregateElement(i);
- if (!Elt)
- return std::nullopt;
-
- if (isa<UndefValue>(Elt))
- continue;
-
- // Bail out if we can't determine if this constant is min/max or if we
- // know that this constant is min/max.
- auto *CI = dyn_cast<ConstantInt>(Elt);
- if (!CI || !ConstantIsOk(CI))
- return std::nullopt;
-
- if (!SafeReplacementConstant)
- SafeReplacementConstant = CI;
- }
- } else if (isa<VectorType>(C->getType())) {
- // Handle scalable splat
- Value *SplatC = C->getSplatValue();
- auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
- // Bail out if the constant can't be safely incremented/decremented.
- if (!CI || !ConstantIsOk(CI))
- return std::nullopt;
- } else {
- // ConstantExpr?
- return std::nullopt;
- }
-
- // It may not be safe to change a compare predicate in the presence of
- // undefined elements, so replace those elements with the first safe constant
- // that we found.
- // TODO: in case of poison, it is safe; let's replace undefs only.
- if (C->containsUndefOrPoisonElement()) {
- assert(SafeReplacementConstant && "Replacement constant not set");
- C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
- }
-
- CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
-
- // Increment or decrement the constant.
- Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
- Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
-
- return std::make_pair(NewPred, NewC);
-}
-
/// If we have an icmp le or icmp ge instruction with a constant operand, turn
/// it into the appropriate icmp lt or icmp gt instruction. This transform
/// allows them to be folded in visitICmpInst.
@@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
if (!Op1C)
return nullptr;
- auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
if (!FlippedStrictness)
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7fd91c72a2fb0e..eca518aa640700 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
// Check the constant we'd have with flipped-strictness predicate.
- auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
if (!FlippedStrictness)
return nullptr;
@@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
Value *RHS;
SelectPatternFlavor SPF;
const DataLayout &DL = BOp->getDataLayout();
- auto Flipped =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);
+ auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1);
if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) {
SPF = getSelectPattern(Predicate).Flavor;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/73/builds/11355 Here is the relevant piece of the build log for the reference
|
Needed by #121958.