Skip to content

Commit 1baa385

Browse files
authored
[IR][PatternMatch] Only accept poison in getSplatValue() (#89159)
In #88217 a large set of matchers was changed to only accept poison values in splats, but not undef values. This is because we now use poison for non-demanded vector elements, and allowing undef can cause correctness issues. This patch covers the remaining matchers by changing the AllowUndef parameter of getSplatValue() to AllowPoison instead. We also carry out corresponding renames in matchers. As a followup, we may want to change the default for things like m_APInt to m_APIntAllowPoison (as this is much less risky when only allowing poison), but this change doesn't do that. There is one caveat here: We have a single place (X86FixupVectorConstants) which does require handling of vector splats with undefs. This is because this works on backend constant pool entries, which currently still use undef instead of poison for non-demanded elements (because SDAG as a whole does not have an explicit poison representation). As it's just the single use, I've open-coded a getSplatValueAllowUndef() helper there, to discourage use in any other places.
1 parent 7ec342b commit 1baa385

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+479
-413
lines changed

llvm/include/llvm/IR/Constant.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,9 @@ class Constant : public User {
146146
Constant *getAggregateElement(Constant *Elt) const;
147147

148148
/// If all elements of the vector constant have the same value, return that
149-
/// value. Otherwise, return nullptr. Ignore undefined elements by setting
150-
/// AllowUndefs to true.
151-
Constant *getSplatValue(bool AllowUndefs = false) const;
149+
/// value. Otherwise, return nullptr. Ignore poison elements by setting
150+
/// AllowPoison to true.
151+
Constant *getSplatValue(bool AllowPoison = false) const;
152152

153153
/// If C is a constant integer then return its value, otherwise C must be a
154154
/// vector of constant integers, all equal, and the common value is returned.

llvm/include/llvm/IR/Constants.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,9 @@ class ConstantVector final : public ConstantAggregate {
532532
}
533533

534534
/// If all elements of the vector constant have the same value, return that
535-
/// value. Otherwise, return nullptr. Ignore undefined elements by setting
536-
/// AllowUndefs to true.
537-
Constant *getSplatValue(bool AllowUndefs = false) const;
535+
/// value. Otherwise, return nullptr. Ignore poison elements by setting
536+
/// AllowPoison to true.
537+
Constant *getSplatValue(bool AllowPoison = false) const;
538538

539539
/// Methods for support type inquiry through isa, cast, and dyn_cast:
540540
static bool classof(const Value *V) {

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ inline match_combine_and<LTy, RTy> m_CombineAnd(const LTy &L, const RTy &R) {
243243

244244
struct apint_match {
245245
const APInt *&Res;
246-
bool AllowUndef;
246+
bool AllowPoison;
247247

248-
apint_match(const APInt *&Res, bool AllowUndef)
249-
: Res(Res), AllowUndef(AllowUndef) {}
248+
apint_match(const APInt *&Res, bool AllowPoison)
249+
: Res(Res), AllowPoison(AllowPoison) {}
250250

251251
template <typename ITy> bool match(ITy *V) {
252252
if (auto *CI = dyn_cast<ConstantInt>(V)) {
@@ -256,7 +256,7 @@ struct apint_match {
256256
if (V->getType()->isVectorTy())
257257
if (const auto *C = dyn_cast<Constant>(V))
258258
if (auto *CI =
259-
dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef))) {
259+
dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison))) {
260260
Res = &CI->getValue();
261261
return true;
262262
}
@@ -268,10 +268,10 @@ struct apint_match {
268268
// function for both apint/apfloat.
269269
struct apfloat_match {
270270
const APFloat *&Res;
271-
bool AllowUndef;
271+
bool AllowPoison;
272272

273-
apfloat_match(const APFloat *&Res, bool AllowUndef)
274-
: Res(Res), AllowUndef(AllowUndef) {}
273+
apfloat_match(const APFloat *&Res, bool AllowPoison)
274+
: Res(Res), AllowPoison(AllowPoison) {}
275275

276276
template <typename ITy> bool match(ITy *V) {
277277
if (auto *CI = dyn_cast<ConstantFP>(V)) {
@@ -281,7 +281,7 @@ struct apfloat_match {
281281
if (V->getType()->isVectorTy())
282282
if (const auto *C = dyn_cast<Constant>(V))
283283
if (auto *CI =
284-
dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndef))) {
284+
dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowPoison))) {
285285
Res = &CI->getValueAPF();
286286
return true;
287287
}
@@ -292,35 +292,35 @@ struct apfloat_match {
292292
/// Match a ConstantInt or splatted ConstantVector, binding the
293293
/// specified pointer to the contained APInt.
294294
inline apint_match m_APInt(const APInt *&Res) {
295-
// Forbid undefs by default to maintain previous behavior.
296-
return apint_match(Res, /* AllowUndef */ false);
295+
// Forbid poison by default to maintain previous behavior.
296+
return apint_match(Res, /* AllowPoison */ false);
297297
}
298298

299-
/// Match APInt while allowing undefs in splat vector constants.
300-
inline apint_match m_APIntAllowUndef(const APInt *&Res) {
301-
return apint_match(Res, /* AllowUndef */ true);
299+
/// Match APInt while allowing poison in splat vector constants.
300+
inline apint_match m_APIntAllowPoison(const APInt *&Res) {
301+
return apint_match(Res, /* AllowPoison */ true);
302302
}
303303

304-
/// Match APInt while forbidding undefs in splat vector constants.
305-
inline apint_match m_APIntForbidUndef(const APInt *&Res) {
306-
return apint_match(Res, /* AllowUndef */ false);
304+
/// Match APInt while forbidding poison in splat vector constants.
305+
inline apint_match m_APIntForbidPoison(const APInt *&Res) {
306+
return apint_match(Res, /* AllowPoison */ false);
307307
}
308308

309309
/// Match a ConstantFP or splatted ConstantVector, binding the
310310
/// specified pointer to the contained APFloat.
311311
inline apfloat_match m_APFloat(const APFloat *&Res) {
312312
// Forbid undefs by default to maintain previous behavior.
313-
return apfloat_match(Res, /* AllowUndef */ false);
313+
return apfloat_match(Res, /* AllowPoison */ false);
314314
}
315315

316-
/// Match APFloat while allowing undefs in splat vector constants.
317-
inline apfloat_match m_APFloatAllowUndef(const APFloat *&Res) {
318-
return apfloat_match(Res, /* AllowUndef */ true);
316+
/// Match APFloat while allowing poison in splat vector constants.
317+
inline apfloat_match m_APFloatAllowPoison(const APFloat *&Res) {
318+
return apfloat_match(Res, /* AllowPoison */ true);
319319
}
320320

321-
/// Match APFloat while forbidding undefs in splat vector constants.
322-
inline apfloat_match m_APFloatForbidUndef(const APFloat *&Res) {
323-
return apfloat_match(Res, /* AllowUndef */ false);
321+
/// Match APFloat while forbidding poison in splat vector constants.
322+
inline apfloat_match m_APFloatForbidPoison(const APFloat *&Res) {
323+
return apfloat_match(Res, /* AllowPoison */ false);
324324
}
325325

326326
template <int64_t Val> struct constantint_match {
@@ -418,7 +418,7 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
418418

419419
/// This helper class is used to match scalar and vector constants that
420420
/// satisfy a specified predicate, and bind them to an APFloat.
421-
/// Undefs are allowed in splat vector constants.
421+
/// Poison is allowed in splat vector constants.
422422
template <typename Predicate> struct apf_pred_ty : public Predicate {
423423
const APFloat *&Res;
424424

@@ -433,7 +433,7 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
433433
if (V->getType()->isVectorTy())
434434
if (const auto *C = dyn_cast<Constant>(V))
435435
if (auto *CI = dyn_cast_or_null<ConstantFP>(
436-
C->getSplatValue(/* AllowUndef */ true)))
436+
C->getSplatValue(/* AllowPoison */ true)))
437437
if (this->isValue(CI->getValue())) {
438438
Res = &CI->getValue();
439439
return true;
@@ -883,7 +883,7 @@ struct bind_const_intval_ty {
883883

884884
/// Match a specified integer value or vector of all elements of that
885885
/// value.
886-
template <bool AllowUndefs> struct specific_intval {
886+
template <bool AllowPoison> struct specific_intval {
887887
const APInt &Val;
888888

889889
specific_intval(const APInt &V) : Val(V) {}
@@ -892,13 +892,13 @@ template <bool AllowUndefs> struct specific_intval {
892892
const auto *CI = dyn_cast<ConstantInt>(V);
893893
if (!CI && V->getType()->isVectorTy())
894894
if (const auto *C = dyn_cast<Constant>(V))
895-
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
895+
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison));
896896

897897
return CI && APInt::isSameValue(CI->getValue(), Val);
898898
}
899899
};
900900

901-
template <bool AllowUndefs> struct specific_intval64 {
901+
template <bool AllowPoison> struct specific_intval64 {
902902
uint64_t Val;
903903

904904
specific_intval64(uint64_t V) : Val(V) {}
@@ -907,7 +907,7 @@ template <bool AllowUndefs> struct specific_intval64 {
907907
const auto *CI = dyn_cast<ConstantInt>(V);
908908
if (!CI && V->getType()->isVectorTy())
909909
if (const auto *C = dyn_cast<Constant>(V))
910-
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
910+
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison));
911911

912912
return CI && CI->getValue() == Val;
913913
}
@@ -923,11 +923,11 @@ inline specific_intval64<false> m_SpecificInt(uint64_t V) {
923923
return specific_intval64<false>(V);
924924
}
925925

926-
inline specific_intval<true> m_SpecificIntAllowUndef(const APInt &V) {
926+
inline specific_intval<true> m_SpecificIntAllowPoison(const APInt &V) {
927927
return specific_intval<true>(V);
928928
}
929929

930-
inline specific_intval64<true> m_SpecificIntAllowUndef(uint64_t V) {
930+
inline specific_intval64<true> m_SpecificIntAllowPoison(uint64_t V) {
931931
return specific_intval64<true>(V);
932932
}
933933

@@ -1699,9 +1699,9 @@ struct m_SpecificMask {
16991699
bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
17001700
};
17011701

1702-
struct m_SplatOrUndefMask {
1702+
struct m_SplatOrPoisonMask {
17031703
int &SplatIndex;
1704-
m_SplatOrUndefMask(int &SplatIndex) : SplatIndex(SplatIndex) {}
1704+
m_SplatOrPoisonMask(int &SplatIndex) : SplatIndex(SplatIndex) {}
17051705
bool match(ArrayRef<int> Mask) {
17061706
const auto *First = find_if(Mask, [](int Elem) { return Elem != -1; });
17071707
if (First == Mask.end())

llvm/lib/Analysis/CmpInstAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
7979
using namespace PatternMatch;
8080

8181
const APInt *C;
82-
if (!match(RHS, m_APIntAllowUndef(C)))
82+
if (!match(RHS, m_APIntAllowPoison(C)))
8383
return false;
8484

8585
switch (Pred) {

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3023,7 +3023,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
30233023

30243024
Value *X;
30253025
const APInt *C;
3026-
if (!match(RHS, m_APIntAllowUndef(C)))
3026+
if (!match(RHS, m_APIntAllowPoison(C)))
30273027
return nullptr;
30283028

30293029
// Sign-bit checks can be optimized to true/false after unsigned
@@ -3056,9 +3056,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
30563056
// (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC)
30573057
const APInt *MulC;
30583058
if (IIQ.UseInstrInfo && ICmpInst::isEquality(Pred) &&
3059-
((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
3059+
((match(LHS, m_NUWMul(m_Value(), m_APIntAllowPoison(MulC))) &&
30603060
*MulC != 0 && C->urem(*MulC) != 0) ||
3061-
(match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
3061+
(match(LHS, m_NSWMul(m_Value(), m_APIntAllowPoison(MulC))) &&
30623062
*MulC != 0 && C->srem(*MulC) != 0)))
30633063
return ConstantInt::get(ITy, Pred == ICmpInst::ICMP_NE);
30643064

@@ -3203,7 +3203,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
32033203

32043204
// (sub C, X) == X, C is odd --> false
32053205
// (sub C, X) != X, C is odd --> true
3206-
if (match(LBO, m_Sub(m_APIntAllowUndef(C), m_Specific(RHS))) &&
3206+
if (match(LBO, m_Sub(m_APIntAllowPoison(C), m_Specific(RHS))) &&
32073207
(*C & 1) == 1 && ICmpInst::isEquality(Pred))
32083208
return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy);
32093209

@@ -3354,7 +3354,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
33543354
// (C2 << X) != C --> true
33553355
const APInt *C;
33563356
if (match(LHS, m_Shl(m_Power2(), m_Value())) &&
3357-
match(RHS, m_APIntAllowUndef(C)) && !C->isPowerOf2()) {
3357+
match(RHS, m_APIntAllowPoison(C)) && !C->isPowerOf2()) {
33583358
// C2 << X can equal zero in some circumstances.
33593359
// This simplification might be unsafe if C is zero.
33603360
//
@@ -4105,7 +4105,7 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
41054105
}
41064106

41074107
const APFloat *C = nullptr;
4108-
match(RHS, m_APFloatAllowUndef(C));
4108+
match(RHS, m_APFloatAllowPoison(C));
41094109
std::optional<KnownFPClass> FullKnownClassLHS;
41104110

41114111
// Lazily compute the possible classes for LHS. Avoid computing it twice if
@@ -6459,7 +6459,7 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
64596459
ReturnType, MinMaxIntrinsic::getSaturationPoint(IID, BitWidth));
64606460

64616461
const APInt *C;
6462-
if (match(Op1, m_APIntAllowUndef(C))) {
6462+
if (match(Op1, m_APIntAllowPoison(C))) {
64636463
// Clamp to limit value. For example:
64646464
// umax(i8 %x, i8 255) --> 255
64656465
if (*C == MinMaxIntrinsic::getSaturationPoint(IID, BitWidth))

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4116,7 +4116,7 @@ std::pair<Value *, FPClassTest> llvm::fcmpToClassTest(FCmpInst::Predicate Pred,
41164116
Value *LHS, Value *RHS,
41174117
bool LookThroughSrc) {
41184118
const APFloat *ConstRHS;
4119-
if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
4119+
if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
41204120
return {nullptr, fcAllFlags};
41214121

41224122
return fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
@@ -4517,7 +4517,7 @@ std::tuple<Value *, FPClassTest, FPClassTest>
45174517
llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
45184518
Value *RHS, bool LookThroughSrc) {
45194519
const APFloat *ConstRHS;
4520-
if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
4520+
if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
45214521
return {nullptr, fcAllFlags, fcAllFlags};
45224522

45234523
// TODO: Just call computeKnownFPClass for RHS to handle non-constants.

llvm/lib/IR/Constants.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,14 +1696,14 @@ void ConstantVector::destroyConstantImpl() {
16961696
getType()->getContext().pImpl->VectorConstants.remove(this);
16971697
}
16981698

1699-
Constant *Constant::getSplatValue(bool AllowUndefs) const {
1699+
Constant *Constant::getSplatValue(bool AllowPoison) const {
17001700
assert(this->getType()->isVectorTy() && "Only valid for vectors!");
17011701
if (isa<ConstantAggregateZero>(this))
17021702
return getNullValue(cast<VectorType>(getType())->getElementType());
17031703
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
17041704
return CV->getSplatValue();
17051705
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
1706-
return CV->getSplatValue(AllowUndefs);
1706+
return CV->getSplatValue(AllowPoison);
17071707

17081708
// Check if this is a constant expression splat of the form returned by
17091709
// ConstantVector::getSplat()
@@ -1728,7 +1728,7 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const {
17281728
return nullptr;
17291729
}
17301730

1731-
Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
1731+
Constant *ConstantVector::getSplatValue(bool AllowPoison) const {
17321732
// Check out first element.
17331733
Constant *Elt = getOperand(0);
17341734
// Then make sure all remaining elements point to the same value.
@@ -1738,15 +1738,15 @@ Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
17381738
continue;
17391739

17401740
// Strict mode: any mismatch is not a splat.
1741-
if (!AllowUndefs)
1741+
if (!AllowPoison)
17421742
return nullptr;
17431743

1744-
// Allow undefs mode: ignore undefined elements.
1745-
if (isa<UndefValue>(OpC))
1744+
// Allow poison mode: ignore poison elements.
1745+
if (isa<PoisonValue>(OpC))
17461746
continue;
17471747

17481748
// If we do not have a defined element yet, use the current operand.
1749-
if (isa<UndefValue>(Elt))
1749+
if (isa<PoisonValue>(Elt))
17501750
Elt = OpC;
17511751

17521752
if (OpC != Elt)

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -906,8 +906,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
906906

907907
const APFloat *CF = nullptr;
908908
const APInt *CINT = nullptr;
909-
if (!match(opr1, m_APFloatAllowUndef(CF)))
910-
match(opr1, m_APIntAllowUndef(CINT));
909+
if (!match(opr1, m_APFloatAllowPoison(CF)))
910+
match(opr1, m_APIntAllowPoison(CINT));
911911

912912
// 0x1111111 means that we don't do anything for this call.
913913
int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
@@ -1039,7 +1039,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
10391039
Constant *cnval = nullptr;
10401040
if (getVecSize(FInfo) == 1) {
10411041
CF = nullptr;
1042-
match(opr0, m_APFloatAllowUndef(CF));
1042+
match(opr0, m_APFloatAllowPoison(CF));
10431043

10441044
if (CF) {
10451045
double V = (getArgType(FInfo) == AMDGPULibFunc::F32)

llvm/lib/Target/X86/X86FixupVectorConstants.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,23 @@ FunctionPass *llvm::createX86FixupVectorConstants() {
6464
return new X86FixupVectorConstantsPass();
6565
}
6666

67+
/// Normally, we only allow poison in vector splats. However, as this is part
68+
/// of the backend, and working with the DAG representation, which currently
69+
/// only natively represents undef values, we need to accept undefs here.
70+
static Constant *getSplatValueAllowUndef(const ConstantVector *C) {
71+
Constant *Res = nullptr;
72+
for (Value *Op : C->operands()) {
73+
Constant *OpC = cast<Constant>(Op);
74+
if (isa<UndefValue>(OpC))
75+
continue;
76+
if (!Res)
77+
Res = OpC;
78+
else if (Res != OpC)
79+
return nullptr;
80+
}
81+
return Res;
82+
}
83+
6784
// Attempt to extract the full width of bits data from the constant.
6885
static std::optional<APInt> extractConstantBits(const Constant *C) {
6986
unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
@@ -78,7 +95,7 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
7895
return CFP->getValue().bitcastToAPInt();
7996

8097
if (auto *CV = dyn_cast<ConstantVector>(C)) {
81-
if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
98+
if (auto *CVSplat = getSplatValueAllowUndef(CV)) {
8299
if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
83100
assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
84101
return APInt::getSplat(NumBits, *Bits);

0 commit comments

Comments
 (0)