Skip to content

[InstCombine] Drop poison-generating/UB-implying param attrs after changing operands #115988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,9 @@ class CallBase : public Instruction {
paramHasAttr(ArgNo, Attribute::DereferenceableOrNull);
}

/// Drop parameter attributes that may cause this instruction to cause UB.
void dropPoisonGeneratingAndUBImplyingParamAttrs(unsigned ArgNo);

/// Determine if there are is an inalloca argument. Only the last argument can
/// have the inalloca attribute.
bool hasInAllocaArgument() const {
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
return &I;
}

/// Replace operand of a call-like instruction and add old operand to the
/// worklist. Also drop poison generating and UB implying parameter
/// attributes.
Instruction *replaceArgOperand(CallBase &I, unsigned OpNum, Value *V) {
I.dropPoisonGeneratingAndUBImplyingParamAttrs(OpNum);
return replaceOperand(I, OpNum, V);
}

/// Replace use and add the previously used value to the worklist.
void replaceUse(Use &U, Value *NewValue) {
Value *OldOp = U;
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
Expand Down Expand Up @@ -330,6 +331,16 @@ unsigned CallBase::getNumSubclassExtraOperandsDynamic() const {
return cast<CallBrInst>(this)->getNumIndirectDests() + 1;
}

void CallBase::dropPoisonGeneratingAndUBImplyingParamAttrs(unsigned ArgNo) {
AttributeMask AM = AttributeFuncs::getUBImplyingAttributes();
// TODO: Add a helper AttributeFuncs::getPoisonGeneratingAttributes
AM.addAttribute(Attribute::NoFPClass);
AM.addAttribute(Attribute::Range);
AM.addAttribute(Attribute::Alignment);
AM.addAttribute(Attribute::NonNull);
removeParamAttrs(ArgNo, AM);
}

bool CallBase::isIndirectCall() const {
const Value *V = getCalledOperand();
if (isa<Function>(V) || isa<Constant>(V))
Expand Down
111 changes: 56 additions & 55 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
APInt PoisonElts(DemandedElts.getBitWidth(), 0);
if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
PoisonElts))
return replaceOperand(II, 0, V);
return replaceArgOperand(II, 0, V);

return nullptr;
}
Expand Down Expand Up @@ -430,10 +430,10 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
APInt PoisonElts(DemandedElts.getBitWidth(), 0);
if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
PoisonElts))
return replaceOperand(II, 0, V);
return replaceArgOperand(II, 0, V);
if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts,
PoisonElts))
return replaceOperand(II, 1, V);
return replaceArgOperand(II, 1, V);

return nullptr;
}
Expand Down Expand Up @@ -513,11 +513,11 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
if (IsTZ) {
// cttz(-x) -> cttz(x)
if (match(Op0, m_Neg(m_Value(X))))
return IC.replaceOperand(II, 0, X);
return IC.replaceArgOperand(II, 0, X);

// cttz(-x & x) -> cttz(x)
if (match(Op0, m_c_And(m_Neg(m_Value(X)), m_Deferred(X))))
return IC.replaceOperand(II, 0, X);
return IC.replaceArgOperand(II, 0, X);

// cttz(sext(x)) -> cttz(zext(x))
if (match(Op0, m_OneUse(m_SExt(m_Value(X))))) {
Expand All @@ -541,10 +541,10 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
Value *Y;
SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor;
if (SPF == SPF_ABS || SPF == SPF_NABS)
return IC.replaceOperand(II, 0, X);
return IC.replaceArgOperand(II, 0, X);

if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X))))
return IC.replaceOperand(II, 0, X);
return IC.replaceArgOperand(II, 0, X);

// cttz(shl(%const, %val), 1) --> add(cttz(%const, 1), %val)
if (match(Op0, m_Shl(m_ImmConstant(C), m_Value(X))) &&
Expand Down Expand Up @@ -636,13 +636,13 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
// ctpop(bitreverse(x)) -> ctpop(x)
// ctpop(bswap(x)) -> ctpop(x)
if (match(Op0, m_BitReverse(m_Value(X))) || match(Op0, m_BSwap(m_Value(X))))
return IC.replaceOperand(II, 0, X);
return IC.replaceArgOperand(II, 0, X);

// ctpop(rot(x)) -> ctpop(x)
if ((match(Op0, m_FShl(m_Value(X), m_Value(Y), m_Value())) ||
match(Op0, m_FShr(m_Value(X), m_Value(Y), m_Value()))) &&
X == Y)
return IC.replaceOperand(II, 0, X);
return IC.replaceArgOperand(II, 0, X);

// ctpop(x | -x) -> bitwidth - cttz(x, false)
if (Op0->hasOneUse() &&
Expand Down Expand Up @@ -814,6 +814,15 @@ static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) {
if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) {
Call.setArgOperand(0, Arg1);
Call.setArgOperand(1, Arg0);
auto CallAttr = Call.getAttributes();
auto LHSAttr = CallAttr.getParamAttrs(0);
auto RHSAttr = CallAttr.getParamAttrs(1);
LLVMContext &Ctx = Call.getContext();
Call.setAttributes(
CallAttr.removeAttributesAtIndex(Ctx, 0)
.removeAttributesAtIndex(Ctx, 1)
.addParamAttributes(Ctx, 0, AttrBuilder(Ctx, RHSAttr))
.addParamAttributes(Ctx, 1, AttrBuilder(Ctx, LHSAttr)));
return &Call;
}
return nullptr;
Expand Down Expand Up @@ -929,13 +938,13 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) {
// is.fpclass (fneg x), mask -> is.fpclass x, (fneg mask)

II.setArgOperand(1, ConstantInt::get(Src1->getType(), fneg(Mask)));
return replaceOperand(II, 0, FNegSrc);
return replaceArgOperand(II, 0, FNegSrc);
}

Value *FAbsSrc;
if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) {
II.setArgOperand(1, ConstantInt::get(Src1->getType(), inverse_fabs(Mask)));
return replaceOperand(II, 0, FAbsSrc);
return replaceArgOperand(II, 0, FAbsSrc);
}

if ((OrderedMask == fcInf || OrderedInvertedMask == fcInf) &&
Expand Down Expand Up @@ -1695,8 +1704,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {

if (II->isCommutative()) {
if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) {
replaceOperand(*II, 0, Pair->first);
replaceOperand(*II, 1, Pair->second);
replaceArgOperand(*II, 0, Pair->first);
replaceArgOperand(*II, 1, Pair->second);
return II;
}

Expand Down Expand Up @@ -1733,11 +1742,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// TODO: Copy nsw if it was present on the neg?
Value *X;
if (match(IIOperand, m_Neg(m_Value(X))))
return replaceOperand(*II, 0, X);
return replaceArgOperand(*II, 0, X);
if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X)))))
return replaceOperand(*II, 0, X);
return replaceArgOperand(*II, 0, X);
if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
return replaceOperand(*II, 0, X);
return replaceArgOperand(*II, 0, X);

Value *Y;
// abs(a * abs(b)) -> abs(a * b)
Expand All @@ -1747,7 +1756,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
bool NSW =
cast<Instruction>(IIOperand)->hasNoSignedWrap() && IntMinIsPoison;
auto *XY = NSW ? Builder.CreateNSWMul(X, Y) : Builder.CreateMul(X, Y);
return replaceOperand(*II, 0, XY);
return replaceArgOperand(*II, 0, XY);
}

if (std::optional<bool> Known =
Expand Down Expand Up @@ -2122,7 +2131,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
match(II->getArgOperand(0), m_FAbs(m_Value(X))) ||
match(II->getArgOperand(0),
m_Intrinsic<Intrinsic::copysign>(m_Value(X), m_Value())))
return replaceOperand(*II, 0, X);
return replaceArgOperand(*II, 0, X);
}
}
break;
Expand Down Expand Up @@ -2152,7 +2161,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (!ModuloC)
return nullptr;
if (ModuloC != ShAmtC)
return replaceOperand(*II, 2, ModuloC);
return replaceArgOperand(*II, 2, ModuloC);

assert(match(ConstantFoldCompareInstOperands(ICmpInst::ICMP_UGT, WidthC,
ShAmtC, DL),
Expand Down Expand Up @@ -2234,8 +2243,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// TODO: If InnerMask == Op1, we could copy attributes from inner
// callsite -> outer callsite.
Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask);
replaceOperand(CI, 0, InnerPtr);
replaceOperand(CI, 1, NewMask);
replaceArgOperand(CI, 0, InnerPtr);
replaceArgOperand(CI, 1, NewMask);
Changed = true;
}

Expand Down Expand Up @@ -2520,8 +2529,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *A, *B;
if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) &&
match(II->getArgOperand(1), m_FNeg(m_Value(B)))) {
replaceOperand(*II, 0, A);
replaceOperand(*II, 1, B);
replaceArgOperand(*II, 0, A);
replaceArgOperand(*II, 1, B);
return II;
}

Expand Down Expand Up @@ -2556,8 +2565,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (ElementCount::isKnownGT(NegatedCount, OtherCount) &&
ElementCount::isKnownLT(OtherCount, RetCount)) {
Value *InverseOtherOp = Builder.CreateFNeg(OtherOp);
replaceOperand(*II, NegatedOpArg, OpNotNeg);
replaceOperand(*II, OtherOpArg, InverseOtherOp);
replaceArgOperand(*II, NegatedOpArg, OpNotNeg);
replaceArgOperand(*II, OtherOpArg, InverseOtherOp);
return II;
}
// (-A) * B -> -(A * B), if it is cheaper to negate the result
Expand Down Expand Up @@ -2589,16 +2598,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *Src2 = II->getArgOperand(2);
Value *X, *Y;
if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) {
replaceOperand(*II, 0, X);
replaceOperand(*II, 1, Y);
replaceArgOperand(*II, 0, X);
replaceArgOperand(*II, 1, Y);
return II;
}

// fma fabs(x), fabs(x), z -> fma x, x, z
if (match(Src0, m_FAbs(m_Value(X))) &&
match(Src1, m_FAbs(m_Specific(X)))) {
replaceOperand(*II, 0, X);
replaceOperand(*II, 1, X);
replaceArgOperand(*II, 0, X);
replaceArgOperand(*II, 1, X);
return II;
}

Expand Down Expand Up @@ -2645,7 +2654,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// copysign Mag, (copysign ?, X) --> copysign Mag, X
Value *X;
if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X))))
return replaceOperand(*II, 1, X);
return replaceArgOperand(*II, 1, X);

// Clear sign-bit of constant magnitude:
// copysign -MagC, X --> copysign MagC, X
Expand All @@ -2654,14 +2663,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (match(Mag, m_APFloat(MagC)) && MagC->isNegative()) {
APFloat PosMagC = *MagC;
PosMagC.clearSign();
return replaceOperand(*II, 0, ConstantFP::get(Mag->getType(), PosMagC));
return replaceArgOperand(*II, 0,
ConstantFP::get(Mag->getType(), PosMagC));
}

// Peek through changes of magnitude's sign-bit. This call rewrites those:
// copysign (fabs X), Sign --> copysign X, Sign
// copysign (fneg X), Sign --> copysign X, Sign
if (match(Mag, m_FAbs(m_Value(X))) || match(Mag, m_FNeg(m_Value(X))))
return replaceOperand(*II, 0, X);
return replaceArgOperand(*II, 0, X);

break;
}
Expand Down Expand Up @@ -2689,10 +2699,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
// fabs (select Cond, -FVal, FVal) --> fabs FVal
if (match(TVal, m_FNeg(m_Specific(FVal))))
return replaceOperand(*II, 0, FVal);
return replaceArgOperand(*II, 0, FVal);
// fabs (select Cond, TVal, -TVal) --> fabs TVal
if (match(FVal, m_FNeg(m_Specific(TVal))))
return replaceOperand(*II, 0, TVal);
return replaceArgOperand(*II, 0, TVal);
}

Value *Magnitude, *Sign;
Expand Down Expand Up @@ -2731,7 +2741,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// cos(-x) --> cos(x)
// cos(fabs(x)) --> cos(x)
// cos(copysign(x, y)) --> cos(x)
return replaceOperand(*II, 0, X);
return replaceArgOperand(*II, 0, X);
}
break;
}
Expand Down Expand Up @@ -2774,8 +2784,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// width.
Value *NewExp = Builder.CreateAdd(InnerExp, Exp);
II->setArgOperand(1, NewExp);
II->dropPoisonGeneratingAndUBImplyingParamAttrs(1);
II->setFastMathFlags(InnerFlags); // Or the inner flags.
return replaceOperand(*II, 0, InnerSrc);
return replaceArgOperand(*II, 0, InnerSrc);
}
}

Expand Down Expand Up @@ -3461,11 +3472,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *Arg = II->getArgOperand(0);
Value *Vect;

if (Value *NewOp =
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
replaceUse(II->getOperandUse(0), NewOp);
return II;
}
if (Value *NewOp = simplifyReductionOperand(Arg, /*CanReorderLanes=*/true))
return replaceArgOperand(*II, 0, NewOp);

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
Expand Down Expand Up @@ -3501,8 +3509,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {

if (Value *NewOp =
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
replaceUse(II->getOperandUse(0), NewOp);
return II;
return replaceArgOperand(*II, 0, NewOp);
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
Expand Down Expand Up @@ -3535,10 +3542,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *Vect;

if (Value *NewOp =
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
replaceUse(II->getOperandUse(0), NewOp);
return II;
}
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true))
return replaceArgOperand(*II, 0, NewOp);

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
Expand Down Expand Up @@ -3566,8 +3571,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {

if (Value *NewOp =
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
replaceUse(II->getOperandUse(0), NewOp);
return II;
return replaceArgOperand(*II, 0, NewOp);
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
Expand Down Expand Up @@ -3597,8 +3601,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {

if (Value *NewOp =
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
replaceUse(II->getOperandUse(0), NewOp);
return II;
return replaceArgOperand(*II, 0, NewOp);
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
Expand Down Expand Up @@ -3639,8 +3642,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {

if (Value *NewOp =
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
replaceUse(II->getOperandUse(0), NewOp);
return II;
return replaceArgOperand(*II, 0, NewOp);
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
Expand Down Expand Up @@ -3674,8 +3676,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
: 0;
Value *Arg = II->getArgOperand(ArgIdx);
if (Value *NewOp = simplifyReductionOperand(Arg, CanReorderLanes)) {
replaceUse(II->getOperandUse(ArgIdx), NewOp);
return nullptr;
return replaceArgOperand(*II, ArgIdx, NewOp);
}
break;
}
Expand Down
Loading
Loading