Skip to content

Commit 7e73b7a

Browse files
ElvisWang123pawosm-arm
authored andcommitted
[VPlan] Make VPReductionRecipe a VPRecipeWithIRFlags. NFC (llvm#130881)
This patch change the parent of the VPReductionRecipe from VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/drop/control flags by the VPRecipeWithIRFlags. This will remove the dependency of the underlying instruction. This patch also add a new function `setFastMathFlags()` to the VPRecipeWithIRFlags because the entire reduction chain may contains multiple instructions. And the underlying instruction may not contains the corresponding flags for this reduction. Split from llvm#113903.
1 parent 7bd096b commit 7e73b7a

File tree

3 files changed

+40
-21
lines changed

3 files changed

+40
-21
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
10591059
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
10601060
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
10611061
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
1062+
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
1063+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
10621064
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
10631065
R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
10641066
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
@@ -2656,7 +2658,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
26562658
/// A recipe to represent inloop reduction operations, performing a reduction on
26572659
/// a vector operand into a scalar value, and adding the result to a chain.
26582660
/// The Operands are {ChainOp, VecOp, [Condition]}.
2659-
class VPReductionRecipe : public VPSingleDefRecipe {
2661+
class VPReductionRecipe : public VPRecipeWithIRFlags {
26602662
/// The recurrence decriptor for the reduction in question.
26612663
const RecurrenceDescriptor &RdxDesc;
26622664
bool IsOrdered;
@@ -2667,12 +2669,17 @@ class VPReductionRecipe : public VPSingleDefRecipe {
26672669
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
26682670
Instruction *I, ArrayRef<VPValue *> Operands,
26692671
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2670-
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
2671-
IsOrdered(IsOrdered) {
2672+
: VPRecipeWithIRFlags(SC, Operands,
2673+
isa_and_nonnull<FPMathOperator>(I)
2674+
? R.getFastMathFlags()
2675+
: FastMathFlags(),
2676+
DL),
2677+
RdxDesc(R), IsOrdered(IsOrdered) {
26722678
if (CondOp) {
26732679
IsConditional = true;
26742680
addOperand(CondOp);
26752681
}
2682+
setUnderlyingValue(I);
26762683
}
26772684

26782685
public:
@@ -2738,12 +2745,13 @@ class VPReductionRecipe : public VPSingleDefRecipe {
27382745
/// The Operands are {ChainOp, VecOp, EVL, [Condition]}.
27392746
class VPReductionEVLRecipe : public VPReductionRecipe {
27402747
public:
2741-
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
2748+
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
2749+
DebugLoc DL = {})
27422750
: VPReductionRecipe(
27432751
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
27442752
cast_or_null<Instruction>(R.getUnderlyingValue()),
27452753
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
2746-
R.isOrdered(), R.getDebugLoc()) {}
2754+
R.isOrdered(), DL) {}
27472755

27482756
~VPReductionEVLRecipe() override = default;
27492757

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,7 +2236,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22362236
"In-loop AnyOf reductions aren't currently supported");
22372237
// Propagate the fast-math flags carried by the underlying instruction.
22382238
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
2239-
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2239+
State.Builder.setFastMathFlags(getFastMathFlags());
22402240
State.setDebugLocFrom(getDebugLoc());
22412241
Value *NewVecOp = State.get(getVecOp());
22422242
if (VPValue *Cond = getCondOp()) {
@@ -2283,7 +2283,7 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
22832283
// Propagate the fast-math flags carried by the underlying instruction.
22842284
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
22852285
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
2286-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2286+
Builder.setFastMathFlags(getFastMathFlags());
22872287

22882288
RecurKind Kind = RdxDesc.getRecurrenceKind();
22892289
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
@@ -2320,6 +2320,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23202320
Type *ElementTy = Ctx.Types.inferScalarType(this);
23212321
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
23222322
unsigned Opcode = RdxDesc.getOpcode();
2323+
FastMathFlags FMFs = getFastMathFlags();
23232324

23242325
// TODO: Support any-of and in-loop reductions.
23252326
assert(
@@ -2339,12 +2340,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23392340
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
23402341
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
23412342
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2342-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2343-
Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2343+
return Cost +
2344+
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
23442345
}
23452346

2346-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2347-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2347+
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2348+
Ctx.CostKind);
23482349
}
23492350

23502351
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2355,8 +2356,7 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23552356
O << " = ";
23562357
getChainOp()->printAsOperand(O, SlotTracker);
23572358
O << " +";
2358-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2359-
O << getUnderlyingInstr()->getFastMathFlags();
2359+
printFlags(O);
23602360
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
23612361
getVecOp()->printAsOperand(O, SlotTracker);
23622362
if (isConditional()) {
@@ -2377,8 +2377,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
23772377
O << " = ";
23782378
getChainOp()->printAsOperand(O, SlotTracker);
23792379
O << " +";
2380-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2381-
O << getUnderlyingInstr()->getFastMathFlags();
2380+
printFlags(O);
23822381
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
23832382
getVecOp()->printAsOperand(O, SlotTracker);
23842383
O << ", ";

llvm/unittests/Transforms/Vectorize/VPlanTest.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,29 +1120,35 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
11201120
}
11211121

11221122
{
1123+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1124+
PoisonValue::get(Int32));
11231125
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11241126
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11251127
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1126-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1128+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
11271129
VecOp, false);
11281130
EXPECT_FALSE(Recipe.mayHaveSideEffects());
11291131
EXPECT_FALSE(Recipe.mayReadFromMemory());
11301132
EXPECT_FALSE(Recipe.mayWriteToMemory());
11311133
EXPECT_FALSE(Recipe.mayReadOrWriteMemory());
1134+
delete Add;
11321135
}
11331136

11341137
{
1138+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1139+
PoisonValue::get(Int32));
11351140
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11361141
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11371142
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1138-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1143+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
11391144
VecOp, false);
11401145
VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
11411146
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
11421147
EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
11431148
EXPECT_FALSE(EVLRecipe.mayReadFromMemory());
11441149
EXPECT_FALSE(EVLRecipe.mayWriteToMemory());
11451150
EXPECT_FALSE(EVLRecipe.mayReadOrWriteMemory());
1151+
delete Add;
11461152
}
11471153

11481154
{
@@ -1484,28 +1490,34 @@ TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {
14841490

14851491
TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
14861492
IntegerType *Int32 = IntegerType::get(C, 32);
1493+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1494+
PoisonValue::get(Int32));
14871495
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
14881496
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
14891497
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1490-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1491-
VecOp, false);
1498+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1499+
false);
14921500
EXPECT_TRUE(isa<VPUser>(&Recipe));
14931501
VPRecipeBase *BaseR = &Recipe;
14941502
EXPECT_TRUE(isa<VPUser>(BaseR));
1503+
delete Add;
14951504
}
14961505

14971506
TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
14981507
IntegerType *Int32 = IntegerType::get(C, 32);
1508+
auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
1509+
PoisonValue::get(Int32));
14991510
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
15001511
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
15011512
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1502-
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
1503-
VecOp, false);
1513+
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1514+
false);
15041515
VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
15051516
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
15061517
EXPECT_TRUE(isa<VPUser>(&EVLRecipe));
15071518
VPRecipeBase *BaseR = &EVLRecipe;
15081519
EXPECT_TRUE(isa<VPUser>(BaseR));
1520+
delete Add;
15091521
}
15101522
} // namespace
15111523

0 commit comments

Comments
 (0)