Skip to content

[IR][PatternMatch] Only accept poison in getSplatValue() #89159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions llvm/include/llvm/IR/Constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ class Constant : public User {
Constant *getAggregateElement(Constant *Elt) const;

/// If all elements of the vector constant have the same value, return that
/// value. Otherwise, return nullptr. Ignore undefined elements by setting
/// AllowUndefs to true.
Constant *getSplatValue(bool AllowUndefs = false) const;
/// value. Otherwise, return nullptr. Ignore poison elements by setting
/// AllowPoison to true.
Constant *getSplatValue(bool AllowPoison = false) const;

/// If C is a constant integer then return its value, otherwise C must be a
/// vector of constant integers, all equal, and the common value is returned.
Expand Down
6 changes: 3 additions & 3 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,9 @@ class ConstantVector final : public ConstantAggregate {
}

/// If all elements of the vector constant have the same value, return that
/// value. Otherwise, return nullptr. Ignore undefined elements by setting
/// AllowUndefs to true.
Constant *getSplatValue(bool AllowUndefs = false) const;
/// value. Otherwise, return nullptr. Ignore poison elements by setting
/// AllowPoison to true.
Constant *getSplatValue(bool AllowPoison = false) const;

/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Value *V) {
Expand Down
66 changes: 33 additions & 33 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ inline match_combine_and<LTy, RTy> m_CombineAnd(const LTy &L, const RTy &R) {

struct apint_match {
const APInt *&Res;
bool AllowUndef;
bool AllowPoison;

apint_match(const APInt *&Res, bool AllowUndef)
: Res(Res), AllowUndef(AllowUndef) {}
apint_match(const APInt *&Res, bool AllowPoison)
: Res(Res), AllowPoison(AllowPoison) {}

template <typename ITy> bool match(ITy *V) {
if (auto *CI = dyn_cast<ConstantInt>(V)) {
Expand All @@ -256,7 +256,7 @@ struct apint_match {
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
if (auto *CI =
dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef))) {
dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison))) {
Res = &CI->getValue();
return true;
}
Expand All @@ -268,10 +268,10 @@ struct apint_match {
// function for both apint/apfloat.
struct apfloat_match {
const APFloat *&Res;
bool AllowUndef;
bool AllowPoison;

apfloat_match(const APFloat *&Res, bool AllowUndef)
: Res(Res), AllowUndef(AllowUndef) {}
apfloat_match(const APFloat *&Res, bool AllowPoison)
: Res(Res), AllowPoison(AllowPoison) {}

template <typename ITy> bool match(ITy *V) {
if (auto *CI = dyn_cast<ConstantFP>(V)) {
Expand All @@ -281,7 +281,7 @@ struct apfloat_match {
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
if (auto *CI =
dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndef))) {
dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowPoison))) {
Res = &CI->getValueAPF();
return true;
}
Expand All @@ -292,35 +292,35 @@ struct apfloat_match {
/// Match a ConstantInt or splatted ConstantVector, binding the
/// specified pointer to the contained APInt.
inline apint_match m_APInt(const APInt *&Res) {
// Forbid undefs by default to maintain previous behavior.
return apint_match(Res, /* AllowUndef */ false);
// Forbid poison by default to maintain previous behavior.
return apint_match(Res, /* AllowPoison */ false);
}

/// Match APInt while allowing undefs in splat vector constants.
inline apint_match m_APIntAllowUndef(const APInt *&Res) {
return apint_match(Res, /* AllowUndef */ true);
/// Match APInt while allowing poison in splat vector constants.
inline apint_match m_APIntAllowPoison(const APInt *&Res) {
return apint_match(Res, /* AllowPoison */ true);
}

/// Match APInt while forbidding undefs in splat vector constants.
inline apint_match m_APIntForbidUndef(const APInt *&Res) {
return apint_match(Res, /* AllowUndef */ false);
/// Match APInt while forbidding poison in splat vector constants.
inline apint_match m_APIntForbidPoison(const APInt *&Res) {
return apint_match(Res, /* AllowPoison */ false);
}

/// Match a ConstantFP or splatted ConstantVector, binding the
/// specified pointer to the contained APFloat.
inline apfloat_match m_APFloat(const APFloat *&Res) {
// Forbid undefs by default to maintain previous behavior.
return apfloat_match(Res, /* AllowUndef */ false);
return apfloat_match(Res, /* AllowPoison */ false);
}

/// Match APFloat while allowing undefs in splat vector constants.
inline apfloat_match m_APFloatAllowUndef(const APFloat *&Res) {
return apfloat_match(Res, /* AllowUndef */ true);
/// Match APFloat while allowing poison in splat vector constants.
inline apfloat_match m_APFloatAllowPoison(const APFloat *&Res) {
return apfloat_match(Res, /* AllowPoison */ true);
}

/// Match APFloat while forbidding undefs in splat vector constants.
inline apfloat_match m_APFloatForbidUndef(const APFloat *&Res) {
return apfloat_match(Res, /* AllowUndef */ false);
/// Match APFloat while forbidding poison in splat vector constants.
inline apfloat_match m_APFloatForbidPoison(const APFloat *&Res) {
return apfloat_match(Res, /* AllowPoison */ false);
}

template <int64_t Val> struct constantint_match {
Expand Down Expand Up @@ -418,7 +418,7 @@ template <typename Predicate> struct api_pred_ty : public Predicate {

/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
/// Poison is allowed in splat vector constants.
template <typename Predicate> struct apf_pred_ty : public Predicate {
const APFloat *&Res;

Expand All @@ -433,7 +433,7 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
if (auto *CI = dyn_cast_or_null<ConstantFP>(
C->getSplatValue(/* AllowUndef */ true)))
C->getSplatValue(/* AllowPoison */ true)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we default the corresponding api_pred_ty to true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, but not as part of this change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been looking into this. From a very quick inspection of all relevant matchers, I think only the usage in foldSelectICmpAndBinOp is problematic if poison is allowed.

However, our test coverage seems to be pretty bad -- allowing poison in api_pred_ty only results in test changes for a single transform :( At the same time, I don't really want to spend time adding poison vector tests for dozens of affected transforms...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usage in foldSelectICmpAndBinOp is fine after all. I submitted #89188 to get the discussion going on this, not sure whether our current test coverage will be a blocker for this or not.

if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
Expand Down Expand Up @@ -883,7 +883,7 @@ struct bind_const_intval_ty {

/// Match a specified integer value or vector of all elements of that
/// value.
template <bool AllowUndefs> struct specific_intval {
template <bool AllowPoison> struct specific_intval {
const APInt &Val;

specific_intval(const APInt &V) : Val(V) {}
Expand All @@ -892,13 +892,13 @@ template <bool AllowUndefs> struct specific_intval {
const auto *CI = dyn_cast<ConstantInt>(V);
if (!CI && V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison));

return CI && APInt::isSameValue(CI->getValue(), Val);
}
};

template <bool AllowUndefs> struct specific_intval64 {
template <bool AllowPoison> struct specific_intval64 {
uint64_t Val;

specific_intval64(uint64_t V) : Val(V) {}
Expand All @@ -907,7 +907,7 @@ template <bool AllowUndefs> struct specific_intval64 {
const auto *CI = dyn_cast<ConstantInt>(V);
if (!CI && V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison));

return CI && CI->getValue() == Val;
}
Expand All @@ -923,11 +923,11 @@ inline specific_intval64<false> m_SpecificInt(uint64_t V) {
return specific_intval64<false>(V);
}

inline specific_intval<true> m_SpecificIntAllowUndef(const APInt &V) {
inline specific_intval<true> m_SpecificIntAllowPoison(const APInt &V) {
return specific_intval<true>(V);
}

inline specific_intval64<true> m_SpecificIntAllowUndef(uint64_t V) {
inline specific_intval64<true> m_SpecificIntAllowPoison(uint64_t V) {
return specific_intval64<true>(V);
}

Expand Down Expand Up @@ -1699,9 +1699,9 @@ struct m_SpecificMask {
bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
};

struct m_SplatOrUndefMask {
struct m_SplatOrPoisonMask {
int &SplatIndex;
m_SplatOrUndefMask(int &SplatIndex) : SplatIndex(SplatIndex) {}
m_SplatOrPoisonMask(int &SplatIndex) : SplatIndex(SplatIndex) {}
bool match(ArrayRef<int> Mask) {
const auto *First = find_if(Mask, [](int Elem) { return Elem != -1; });
if (First == Mask.end())
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/CmpInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
using namespace PatternMatch;

const APInt *C;
if (!match(RHS, m_APIntAllowUndef(C)))
if (!match(RHS, m_APIntAllowPoison(C)))
return false;

switch (Pred) {
Expand Down
14 changes: 7 additions & 7 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3023,7 +3023,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,

Value *X;
const APInt *C;
if (!match(RHS, m_APIntAllowUndef(C)))
if (!match(RHS, m_APIntAllowPoison(C)))
return nullptr;

// Sign-bit checks can be optimized to true/false after unsigned
Expand Down Expand Up @@ -3056,9 +3056,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
// (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC)
const APInt *MulC;
if (IIQ.UseInstrInfo && ICmpInst::isEquality(Pred) &&
((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
((match(LHS, m_NUWMul(m_Value(), m_APIntAllowPoison(MulC))) &&
*MulC != 0 && C->urem(*MulC) != 0) ||
(match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
(match(LHS, m_NSWMul(m_Value(), m_APIntAllowPoison(MulC))) &&
*MulC != 0 && C->srem(*MulC) != 0)))
return ConstantInt::get(ITy, Pred == ICmpInst::ICMP_NE);

Expand Down Expand Up @@ -3203,7 +3203,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,

// (sub C, X) == X, C is odd --> false
// (sub C, X) != X, C is odd --> true
if (match(LBO, m_Sub(m_APIntAllowUndef(C), m_Specific(RHS))) &&
if (match(LBO, m_Sub(m_APIntAllowPoison(C), m_Specific(RHS))) &&
(*C & 1) == 1 && ICmpInst::isEquality(Pred))
return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy);

Expand Down Expand Up @@ -3354,7 +3354,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
// (C2 << X) != C --> true
const APInt *C;
if (match(LHS, m_Shl(m_Power2(), m_Value())) &&
match(RHS, m_APIntAllowUndef(C)) && !C->isPowerOf2()) {
match(RHS, m_APIntAllowPoison(C)) && !C->isPowerOf2()) {
// C2 << X can equal zero in some circumstances.
// This simplification might be unsafe if C is zero.
//
Expand Down Expand Up @@ -4105,7 +4105,7 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
}

const APFloat *C = nullptr;
match(RHS, m_APFloatAllowUndef(C));
match(RHS, m_APFloatAllowPoison(C));
std::optional<KnownFPClass> FullKnownClassLHS;

// Lazily compute the possible classes for LHS. Avoid computing it twice if
Expand Down Expand Up @@ -6459,7 +6459,7 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
ReturnType, MinMaxIntrinsic::getSaturationPoint(IID, BitWidth));

const APInt *C;
if (match(Op1, m_APIntAllowUndef(C))) {
if (match(Op1, m_APIntAllowPoison(C))) {
// Clamp to limit value. For example:
// umax(i8 %x, i8 255) --> 255
if (*C == MinMaxIntrinsic::getSaturationPoint(IID, BitWidth))
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4116,7 +4116,7 @@ std::pair<Value *, FPClassTest> llvm::fcmpToClassTest(FCmpInst::Predicate Pred,
Value *LHS, Value *RHS,
bool LookThroughSrc) {
const APFloat *ConstRHS;
if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
return {nullptr, fcAllFlags};

return fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
Expand Down Expand Up @@ -4517,7 +4517,7 @@ std::tuple<Value *, FPClassTest, FPClassTest>
llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
Value *RHS, bool LookThroughSrc) {
const APFloat *ConstRHS;
if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
return {nullptr, fcAllFlags, fcAllFlags};

// TODO: Just call computeKnownFPClass for RHS to handle non-constants.
Expand Down
14 changes: 7 additions & 7 deletions llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,14 +1696,14 @@ void ConstantVector::destroyConstantImpl() {
getType()->getContext().pImpl->VectorConstants.remove(this);
}

Constant *Constant::getSplatValue(bool AllowUndefs) const {
Constant *Constant::getSplatValue(bool AllowPoison) const {
assert(this->getType()->isVectorTy() && "Only valid for vectors!");
if (isa<ConstantAggregateZero>(this))
return getNullValue(cast<VectorType>(getType())->getElementType());
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
return CV->getSplatValue();
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
return CV->getSplatValue(AllowUndefs);
return CV->getSplatValue(AllowPoison);

// Check if this is a constant expression splat of the form returned by
// ConstantVector::getSplat()
Expand All @@ -1728,7 +1728,7 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const {
return nullptr;
}

Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
Constant *ConstantVector::getSplatValue(bool AllowPoison) const {
// Check out first element.
Constant *Elt = getOperand(0);
// Then make sure all remaining elements point to the same value.
Expand All @@ -1738,15 +1738,15 @@ Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
continue;

// Strict mode: any mismatch is not a splat.
if (!AllowUndefs)
if (!AllowPoison)
return nullptr;

// Allow undefs mode: ignore undefined elements.
if (isa<UndefValue>(OpC))
// Allow poison mode: ignore poison elements.
if (isa<PoisonValue>(OpC))
continue;

// If we do not have a defined element yet, use the current operand.
if (isa<UndefValue>(Elt))
if (isa<PoisonValue>(Elt))
Elt = OpC;

if (OpC != Elt)
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -906,8 +906,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,

const APFloat *CF = nullptr;
const APInt *CINT = nullptr;
if (!match(opr1, m_APFloatAllowUndef(CF)))
match(opr1, m_APIntAllowUndef(CINT));
if (!match(opr1, m_APFloatAllowPoison(CF)))
match(opr1, m_APIntAllowPoison(CINT));

// 0x1111111 means that we don't do anything for this call.
int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
Expand Down Expand Up @@ -1039,7 +1039,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
Constant *cnval = nullptr;
if (getVecSize(FInfo) == 1) {
CF = nullptr;
match(opr0, m_APFloatAllowUndef(CF));
match(opr0, m_APFloatAllowPoison(CF));

if (CF) {
double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
Expand Down
19 changes: 18 additions & 1 deletion llvm/lib/Target/X86/X86FixupVectorConstants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ FunctionPass *llvm::createX86FixupVectorConstants() {
return new X86FixupVectorConstantsPass();
}

/// Normally, we only allow poison in vector splats. However, as this is part
/// of the backend, and working with the DAG representation, which currently
/// only natively represents undef values, we need to accept undefs here.
static Constant *getSplatValueAllowUndef(const ConstantVector *C) {
Constant *Res = nullptr;
for (Value *Op : C->operands()) {
Constant *OpC = cast<Constant>(Op);
if (isa<UndefValue>(OpC))
continue;
if (!Res)
Res = OpC;
else if (Res != OpC)
return nullptr;
}
return Res;
}

// Attempt to extract the full width of bits data from the constant.
static std::optional<APInt> extractConstantBits(const Constant *C) {
unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
Expand All @@ -78,7 +95,7 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
return CFP->getValue().bitcastToAPInt();

if (auto *CV = dyn_cast<ConstantVector>(C)) {
if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
if (auto *CVSplat = getSplatValueAllowUndef(CV)) {
if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
return APInt::getSplat(NumBits, *Bits);
Expand Down
Loading