Skip to content

Commit 03e7862

Browse files
authored
[ValueTracking] Move getFlippedStrictnessPredicateAndConstant into ValueTracking. NFC. (#122064)
Needed by #121958.
1 parent 9fc152d commit 03e7862

File tree

5 files changed

+87
-90
lines changed

5 files changed

+87
-90
lines changed

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
11021102
Instruction *OnPathTo,
11031103
DominatorTree *DT);
11041104

1105+
/// Convert an integer comparison with a constant RHS into an equivalent
1106+
/// form with the strictness flipped predicate. Return the new predicate and
1107+
/// corresponding constant RHS if possible. Otherwise return std::nullopt.
1108+
/// E.g., (icmp sgt X, 0) -> (icmp sle X, 1).
1109+
std::optional<std::pair<CmpPredicate, Constant *>>
1110+
getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C);
1111+
11051112
/// Specific patterns of select instructions we can match.
11061113
enum SelectPatternFlavor {
11071114
SPF_UNKNOWN = 0,

llvm/include/llvm/Transforms/InstCombine/InstCombiner.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
184184
return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
185185
}
186186

187-
std::optional<std::pair<
188-
CmpPredicate,
189-
Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
190-
Pred,
191-
Constant *C);
192-
193187
static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
194188
// a ? b : false and a ? true : b are the canonical form of logical and/or.
195189
// This includes !a ? b : false and !a ? true : b. Absorbing the not into

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
86418641
}
86428642
}
86438643

8644+
std::optional<std::pair<CmpPredicate, Constant *>>
8645+
llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
8646+
assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
8647+
"Only for relational integer predicates.");
8648+
if (isa<UndefValue>(C))
8649+
return std::nullopt;
8650+
8651+
Type *Type = C->getType();
8652+
bool IsSigned = ICmpInst::isSigned(Pred);
8653+
8654+
CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
8655+
bool WillIncrement =
8656+
UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
8657+
8658+
// Check if the constant operand can be safely incremented/decremented
8659+
// without overflowing/underflowing.
8660+
auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
8661+
return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
8662+
};
8663+
8664+
Constant *SafeReplacementConstant = nullptr;
8665+
if (auto *CI = dyn_cast<ConstantInt>(C)) {
8666+
// Bail out if the constant can't be safely incremented/decremented.
8667+
if (!ConstantIsOk(CI))
8668+
return std::nullopt;
8669+
} else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
8670+
unsigned NumElts = FVTy->getNumElements();
8671+
for (unsigned i = 0; i != NumElts; ++i) {
8672+
Constant *Elt = C->getAggregateElement(i);
8673+
if (!Elt)
8674+
return std::nullopt;
8675+
8676+
if (isa<UndefValue>(Elt))
8677+
continue;
8678+
8679+
// Bail out if we can't determine if this constant is min/max or if we
8680+
// know that this constant is min/max.
8681+
auto *CI = dyn_cast<ConstantInt>(Elt);
8682+
if (!CI || !ConstantIsOk(CI))
8683+
return std::nullopt;
8684+
8685+
if (!SafeReplacementConstant)
8686+
SafeReplacementConstant = CI;
8687+
}
8688+
} else if (isa<VectorType>(C->getType())) {
8689+
// Handle scalable splat
8690+
Value *SplatC = C->getSplatValue();
8691+
auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
8692+
// Bail out if the constant can't be safely incremented/decremented.
8693+
if (!CI || !ConstantIsOk(CI))
8694+
return std::nullopt;
8695+
} else {
8696+
// ConstantExpr?
8697+
return std::nullopt;
8698+
}
8699+
8700+
// It may not be safe to change a compare predicate in the presence of
8701+
// undefined elements, so replace those elements with the first safe constant
8702+
// that we found.
8703+
// TODO: in case of poison, it is safe; let's replace undefs only.
8704+
if (C->containsUndefOrPoisonElement()) {
8705+
assert(SafeReplacementConstant && "Replacement constant not set");
8706+
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
8707+
}
8708+
8709+
CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
8710+
8711+
// Increment or decrement the constant.
8712+
Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
8713+
Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
8714+
8715+
return std::make_pair(NewPred, NewC);
8716+
}
8717+
86448718
static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
86458719
FastMathFlags FMF,
86468720
Value *CmpLHS, Value *CmpRHS,

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 4 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
24852485
// icmp ule i64 (shl X, 32), 8589934592 ->
24862486
// icmp ule i32 (trunc X, i32), 2 ->
24872487
// icmp ult i32 (trunc X, i32), 3
2488-
if (auto FlippedStrictness =
2489-
InstCombiner::getFlippedStrictnessPredicateAndConstant(
2490-
Pred, ConstantInt::get(ShType->getContext(), C))) {
2488+
if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
2489+
Pred, ConstantInt::get(ShType->getContext(), C))) {
24912490
CmpPred = FlippedStrictness->first;
24922491
RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
24932492
}
@@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
32803279
if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
32813280
// x sgt C-1 <--> x sge C <--> not(x slt C)
32823281
auto FlippedStrictness =
3283-
InstCombiner::getFlippedStrictnessPredicateAndConstant(
3284-
PredB, cast<Constant>(RHS2));
3282+
getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
32853283
if (!FlippedStrictness)
32863284
return false;
32873285
assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
@@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
69086906
return nullptr;
69096907
}
69106908

6911-
std::optional<std::pair<CmpPredicate, Constant *>>
6912-
InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
6913-
Constant *C) {
6914-
assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
6915-
"Only for relational integer predicates.");
6916-
6917-
Type *Type = C->getType();
6918-
bool IsSigned = ICmpInst::isSigned(Pred);
6919-
6920-
CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
6921-
bool WillIncrement =
6922-
UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
6923-
6924-
// Check if the constant operand can be safely incremented/decremented
6925-
// without overflowing/underflowing.
6926-
auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
6927-
return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
6928-
};
6929-
6930-
Constant *SafeReplacementConstant = nullptr;
6931-
if (auto *CI = dyn_cast<ConstantInt>(C)) {
6932-
// Bail out if the constant can't be safely incremented/decremented.
6933-
if (!ConstantIsOk(CI))
6934-
return std::nullopt;
6935-
} else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
6936-
unsigned NumElts = FVTy->getNumElements();
6937-
for (unsigned i = 0; i != NumElts; ++i) {
6938-
Constant *Elt = C->getAggregateElement(i);
6939-
if (!Elt)
6940-
return std::nullopt;
6941-
6942-
if (isa<UndefValue>(Elt))
6943-
continue;
6944-
6945-
// Bail out if we can't determine if this constant is min/max or if we
6946-
// know that this constant is min/max.
6947-
auto *CI = dyn_cast<ConstantInt>(Elt);
6948-
if (!CI || !ConstantIsOk(CI))
6949-
return std::nullopt;
6950-
6951-
if (!SafeReplacementConstant)
6952-
SafeReplacementConstant = CI;
6953-
}
6954-
} else if (isa<VectorType>(C->getType())) {
6955-
// Handle scalable splat
6956-
Value *SplatC = C->getSplatValue();
6957-
auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
6958-
// Bail out if the constant can't be safely incremented/decremented.
6959-
if (!CI || !ConstantIsOk(CI))
6960-
return std::nullopt;
6961-
} else {
6962-
// ConstantExpr?
6963-
return std::nullopt;
6964-
}
6965-
6966-
// It may not be safe to change a compare predicate in the presence of
6967-
// undefined elements, so replace those elements with the first safe constant
6968-
// that we found.
6969-
// TODO: in case of poison, it is safe; let's replace undefs only.
6970-
if (C->containsUndefOrPoisonElement()) {
6971-
assert(SafeReplacementConstant && "Replacement constant not set");
6972-
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
6973-
}
6974-
6975-
CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
6976-
6977-
// Increment or decrement the constant.
6978-
Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
6979-
Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
6980-
6981-
return std::make_pair(NewPred, NewC);
6982-
}
6983-
69846909
/// If we have an icmp le or icmp ge instruction with a constant operand, turn
69856910
/// it into the appropriate icmp lt or icmp gt instruction. This transform
69866911
/// allows them to be folded in visitICmpInst.
@@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
69966921
if (!Op1C)
69976922
return nullptr;
69986923

6999-
auto FlippedStrictness =
7000-
InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
6924+
auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
70016925
if (!FlippedStrictness)
70026926
return nullptr;
70036927

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
16891689
return nullptr;
16901690

16911691
// Check the constant we'd have with flipped-strictness predicate.
1692-
auto FlippedStrictness =
1693-
InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
1692+
auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
16941693
if (!FlippedStrictness)
16951694
return nullptr;
16961695

@@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
19701969
Value *RHS;
19711970
SelectPatternFlavor SPF;
19721971
const DataLayout &DL = BOp->getDataLayout();
1973-
auto Flipped =
1974-
InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);
1972+
auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1);
19751973

19761974
if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) {
19771975
SPF = getSelectPattern(Predicate).Flavor;

0 commit comments

Comments
 (0)