Skip to content

Commit ed19620

Browse files
authored
[VPlan] Make VPReductionRecipe a VPRecipeWithIRFlags. NFC (llvm#130881)
This patch change the parent of the VPReductionRecipe from VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/drop/control flags by the VPRecipeWithIRFlags. This will remove the dependency of the underlying instruction. This patch also add a new function `setFastMathFlags()` to the VPRecipeWithIRFlags because the entire reduction chain may contains multiple instructions. And the underlying instruction may not contains the corresponding flags for this reduction. Split from llvm#113903.
1 parent 00cad3e commit ed19620

File tree

3 files changed

+40
-21
lines changed

3 files changed

+40
-21
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
711711
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
712712
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
713713
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
714+
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
715+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
714716
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
715717
R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
716718
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
@@ -2236,7 +2238,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
22362238
/// A recipe to represent inloop reduction operations, performing a reduction on
22372239
/// a vector operand into a scalar value, and adding the result to a chain.
22382240
/// The Operands are {ChainOp, VecOp, [Condition]}.
2239-
class VPReductionRecipe : public VPSingleDefRecipe {
2241+
class VPReductionRecipe : public VPRecipeWithIRFlags {
22402242
/// The recurrence decriptor for the reduction in question.
22412243
const RecurrenceDescriptor &RdxDesc;
22422244
bool IsOrdered;
@@ -2247,12 +2249,17 @@ class VPReductionRecipe : public VPSingleDefRecipe {
22472249
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
22482250
Instruction *I, ArrayRef<VPValue *> Operands,
22492251
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2250-
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
2251-
IsOrdered(IsOrdered) {
2252+
: VPRecipeWithIRFlags(SC, Operands,
2253+
isa_and_nonnull<FPMathOperator>(I)
2254+
? R.getFastMathFlags()
2255+
: FastMathFlags(),
2256+
DL),
2257+
RdxDesc(R), IsOrdered(IsOrdered) {
22522258
if (CondOp) {
22532259
IsConditional = true;
22542260
addOperand(CondOp);
22552261
}
2262+
setUnderlyingValue(I);
22562263
}
22572264

22582265
public:
@@ -2318,12 +2325,13 @@ class VPReductionRecipe : public VPSingleDefRecipe {
23182325
/// The Operands are {ChainOp, VecOp, EVL, [Condition]}.
23192326
class VPReductionEVLRecipe : public VPReductionRecipe {
23202327
public:
2321-
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
2328+
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
2329+
DebugLoc DL = {})
23222330
: VPReductionRecipe(
23232331
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
23242332
cast_or_null<Instruction>(R.getUnderlyingValue()),
23252333
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
2326-
R.isOrdered(), R.getDebugLoc()) {}
2334+
R.isOrdered(), DL) {}
23272335

23282336
~VPReductionEVLRecipe() override = default;
23292337

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,7 +2290,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22902290
"In-loop AnyOf reductions aren't currently supported");
22912291
// Propagate the fast-math flags carried by the underlying instruction.
22922292
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
2293-
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2293+
State.Builder.setFastMathFlags(getFastMathFlags());
22942294
State.setDebugLocFrom(getDebugLoc());
22952295
Value *NewVecOp = State.get(getVecOp());
22962296
if (VPValue *Cond = getCondOp()) {
@@ -2337,7 +2337,7 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23372337
// Propagate the fast-math flags carried by the underlying instruction.
23382338
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
23392339
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
2340-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2340+
Builder.setFastMathFlags(getFastMathFlags());
23412341

23422342
RecurKind Kind = RdxDesc.getRecurrenceKind();
23432343
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
@@ -2374,6 +2374,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23742374
Type *ElementTy = Ctx.Types.inferScalarType(this);
23752375
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
23762376
unsigned Opcode = RdxDesc.getOpcode();
2377+
FastMathFlags FMFs = getFastMathFlags();
23772378

23782379
// TODO: Support any-of and in-loop reductions.
23792380
assert(
@@ -2393,12 +2394,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23932394
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
23942395
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
23952396
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2396-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2397-
Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2397+
return Cost +
2398+
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
23982399
}
23992400

2400-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2401-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2401+
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2402+
Ctx.CostKind);
24022403
}
24032404

24042405
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2409,8 +2410,7 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24092410
O << " = ";
24102411
getChainOp()->printAsOperand(O, SlotTracker);
24112412
O << " +";
2412-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2413-
O << getUnderlyingInstr()->getFastMathFlags();
2413+
printFlags(O);
24142414
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
24152415
getVecOp()->printAsOperand(O, SlotTracker);
24162416
if (isConditional()) {
@@ -2431,8 +2431,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24312431
O << " = ";
24322432
getChainOp()->printAsOperand(O, SlotTracker);
24332433
O << " +";
2434-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2435-
O << getUnderlyingInstr()->getFastMathFlags();
2434+
printFlags(O);
24362435
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
24372436
getVecOp()->printAsOperand(O, SlotTracker);
24382437
O << ", ";

llvm/unittests/Transforms/Vectorize/VPlanTest.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,29 +1165,35 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
11651165
}
11661166

11671167
{
1168+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1169+
PoisonValue::get(Int32));
11681170
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11691171
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11701172
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1171-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1173+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
11721174
VecOp, false);
11731175
EXPECT_FALSE(Recipe.mayHaveSideEffects());
11741176
EXPECT_FALSE(Recipe.mayReadFromMemory());
11751177
EXPECT_FALSE(Recipe.mayWriteToMemory());
11761178
EXPECT_FALSE(Recipe.mayReadOrWriteMemory());
1179+
delete Add;
11771180
}
11781181

11791182
{
1183+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1184+
PoisonValue::get(Int32));
11801185
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11811186
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11821187
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1183-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1188+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
11841189
VecOp, false);
11851190
VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
11861191
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
11871192
EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
11881193
EXPECT_FALSE(EVLRecipe.mayReadFromMemory());
11891194
EXPECT_FALSE(EVLRecipe.mayWriteToMemory());
11901195
EXPECT_FALSE(EVLRecipe.mayReadOrWriteMemory());
1196+
delete Add;
11911197
}
11921198

11931199
{
@@ -1529,28 +1535,34 @@ TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {
15291535

15301536
TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
15311537
IntegerType *Int32 = IntegerType::get(C, 32);
1538+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1539+
PoisonValue::get(Int32));
15321540
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
15331541
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
15341542
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1535-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1536-
VecOp, false);
1543+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1544+
false);
15371545
EXPECT_TRUE(isa<VPUser>(&Recipe));
15381546
VPRecipeBase *BaseR = &Recipe;
15391547
EXPECT_TRUE(isa<VPUser>(BaseR));
1548+
delete Add;
15401549
}
15411550

15421551
TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
15431552
IntegerType *Int32 = IntegerType::get(C, 32);
1553+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1554+
PoisonValue::get(Int32));
15441555
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
15451556
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
15461557
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1547-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1548-
VecOp, false);
1558+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1559+
false);
15491560
VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
15501561
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
15511562
EXPECT_TRUE(isa<VPUser>(&EVLRecipe));
15521563
VPRecipeBase *BaseR = &EVLRecipe;
15531564
EXPECT_TRUE(isa<VPUser>(BaseR));
1565+
delete Add;
15541566
}
15551567
} // namespace
15561568

0 commit comments

Comments
 (0)