Skip to content

[VPlan] Make VPReductionRecipe a VPRecipeWithIRFlags. NFC #130881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
Expand Down Expand Up @@ -2236,7 +2238,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
/// A recipe to represent inloop reduction operations, performing a reduction on
/// a vector operand into a scalar value, and adding the result to a chain.
/// The Operands are {ChainOp, VecOp, [Condition]}.
class VPReductionRecipe : public VPSingleDefRecipe {
class VPReductionRecipe : public VPRecipeWithIRFlags {
/// The recurrence descriptor for the reduction in question.
const RecurrenceDescriptor &RdxDesc;
bool IsOrdered;
Expand All @@ -2247,12 +2249,17 @@ class VPReductionRecipe : public VPSingleDefRecipe {
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
Instruction *I, ArrayRef<VPValue *> Operands,
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
IsOrdered(IsOrdered) {
: VPRecipeWithIRFlags(SC, Operands,
isa_and_nonnull<FPMathOperator>(I)
? R.getFastMathFlags()
: FastMathFlags(),
DL),
RdxDesc(R), IsOrdered(IsOrdered) {
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
}
setUnderlyingValue(I);
}

public:
Expand Down Expand Up @@ -2318,12 +2325,13 @@ class VPReductionRecipe : public VPSingleDefRecipe {
/// The Operands are {ChainOp, VecOp, EVL, [Condition]}.
class VPReductionEVLRecipe : public VPReductionRecipe {
public:
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
DebugLoc DL = {})
: VPReductionRecipe(
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
cast_or_null<Instruction>(R.getUnderlyingValue()),
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
R.isOrdered(), R.getDebugLoc()) {}
R.isOrdered(), DL) {}

~VPReductionEVLRecipe() override = default;

Expand Down
19 changes: 9 additions & 10 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2290,7 +2290,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
"In-loop AnyOf reductions aren't currently supported");
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
State.Builder.setFastMathFlags(getFastMathFlags());
State.setDebugLocFrom(getDebugLoc());
Value *NewVecOp = State.get(getVecOp());
if (VPValue *Cond = getCondOp()) {
Expand Down Expand Up @@ -2337,7 +2337,7 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
Builder.setFastMathFlags(getFastMathFlags());

RecurKind Kind = RdxDesc.getRecurrenceKind();
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
Expand Down Expand Up @@ -2374,6 +2374,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
Type *ElementTy = Ctx.Types.inferScalarType(this);
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
unsigned Opcode = RdxDesc.getOpcode();
FastMathFlags FMFs = getFastMathFlags();

// TODO: Support any-of and in-loop reductions.
assert(
Expand All @@ -2393,12 +2394,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
return Cost + Ctx.TTI.getMinMaxReductionCost(
Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
return Cost +
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
}

return Cost + Ctx.TTI.getArithmeticReductionCost(
Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
Ctx.CostKind);
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
Expand All @@ -2409,8 +2410,7 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
printFlags(O);
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
getVecOp()->printAsOperand(O, SlotTracker);
if (isConditional()) {
Expand All @@ -2431,8 +2431,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not major but if we add back the isa check then we remove the diff for the integer reduction test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added it back to prevent integer reduction test changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we end up with integer reductions with FMFs set?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand an integer RecurrenceDescriptor has all fast math flags set:

// Start with all flags set because we will intersect this with the reduction
// flags from all the reduction operations.
FastMathFlags FMF = FastMathFlags::getFast();

So currently today an integer VPReductionRecipe will use a builder with all fast math flags set:

  IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
  State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but we shouldn't further propagate incorrect flags to VPReductionRecipe. We should probably check if it is a FP reduction on construction, and if it isn't, set empty (fast-math) flags.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I think @ElvisWang123 had already done something similar in a previous version of this PR like

     if (isa<FPMathOperator>(I))
       setFastMathFlags(R.getFastMathFlags());

in the constructor. Should that be done as a part of this PR or in a follow up?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be done straight away, I added a suggested edit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed checks, thanks!

O << getUnderlyingInstr()->getFastMathFlags();
printFlags(O);
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
getVecOp()->printAsOperand(O, SlotTracker);
O << ", ";
Expand Down
24 changes: 18 additions & 6 deletions llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1165,29 +1165,35 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
}

{
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
PoisonValue::get(Int32));
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
VecOp, false);
EXPECT_FALSE(Recipe.mayHaveSideEffects());
EXPECT_FALSE(Recipe.mayReadFromMemory());
EXPECT_FALSE(Recipe.mayWriteToMemory());
EXPECT_FALSE(Recipe.mayReadOrWriteMemory());
delete Add;
}

{
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
PoisonValue::get(Int32));
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
VecOp, false);
VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
EXPECT_FALSE(EVLRecipe.mayReadFromMemory());
EXPECT_FALSE(EVLRecipe.mayWriteToMemory());
EXPECT_FALSE(EVLRecipe.mayReadOrWriteMemory());
delete Add;
}

{
Expand Down Expand Up @@ -1529,28 +1535,34 @@ TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {

TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
IntegerType *Int32 = IntegerType::get(C, 32);
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
PoisonValue::get(Int32));
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
VecOp, false);
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
false);
EXPECT_TRUE(isa<VPUser>(&Recipe));
VPRecipeBase *BaseR = &Recipe;
EXPECT_TRUE(isa<VPUser>(BaseR));
delete Add;
}

TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
IntegerType *Int32 = IntegerType::get(C, 32);
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
PoisonValue::get(Int32));
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
VecOp, false);
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
false);
VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
EXPECT_TRUE(isa<VPUser>(&EVLRecipe));
VPRecipeBase *BaseR = &EVLRecipe;
EXPECT_TRUE(isa<VPUser>(BaseR));
delete Add;
}
} // namespace

Expand Down
Loading