Skip to content

Commit ad13c64

Browse files
committed
[VPlan] Change parent of VPReductionRecipe to VPRecipeWithIRFlags. NFC
This patch change the parent of the VPReductionRecipe from VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/control flags by the VPRecipeWithIRFlags. This will remove the dependency of the underlying instruction. This patch also add a new function `setFastMathFlags()` to the VPRecipeWithIRFlags because the entire reduction chain may contains multiple instructions. And the underlying instruction may not contains the corresponding flags for this reduction.
1 parent c6198a2 commit ad13c64

File tree

3 files changed

+36
-21
lines changed

3 files changed

+36
-21
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9799,9 +9799,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97999799
if (CM.blockNeedsPredicationForAnyReason(BB))
98009800
CondOp = RecipeBuilder.getBlockInMask(BB);
98019801

9802-
auto *RedRecipe = new VPReductionRecipe(
9803-
RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
9804-
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
9802+
auto *RedRecipe =
9803+
new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
9804+
CondOp, CM.useOrderedReductions(RdxDesc));
98059805
// Append the recipe to the end of the VPBasicBlock because we need to
98069806
// ensure that it comes after all of it's inputs, including CondOp.
98079807
// Delete CurrentLink as it will be invalid if its operand is replaced

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
714714
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
715715
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
716716
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
717+
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
718+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
717719
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
718720
R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
719721
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
@@ -789,6 +791,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
789791
}
790792
}
791793

794+
/// Set fast-math flags for this recipe.
795+
void setFastMathFlags(FastMathFlags FMFs) {
796+
OpType = OperationType::FPMathOp;
797+
this->FMFs = FMFs;
798+
}
799+
792800
CmpInst::Predicate getPredicate() const {
793801
assert(OpType == OperationType::Cmp &&
794802
"recipe doesn't have a compare predicate");
@@ -2287,7 +2295,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
22872295
/// A recipe to represent inloop reduction operations, performing a reduction on
22882296
/// a vector operand into a scalar value, and adding the result to a chain.
22892297
/// The Operands are {ChainOp, VecOp, [Condition]}.
2290-
class VPReductionRecipe : public VPSingleDefRecipe {
2298+
class VPReductionRecipe : public VPRecipeWithIRFlags {
22912299
/// The recurrence decriptor for the reduction in question.
22922300
const RecurrenceDescriptor &RdxDesc;
22932301
bool IsOrdered;
@@ -2297,29 +2305,32 @@ class VPReductionRecipe : public VPSingleDefRecipe {
22972305
protected:
22982306
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
22992307
Instruction *I, ArrayRef<VPValue *> Operands,
2300-
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2301-
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
2308+
VPValue *CondOp, bool IsOrdered)
2309+
: VPRecipeWithIRFlags(SC, Operands, *I), RdxDesc(R),
23022310
IsOrdered(IsOrdered) {
23032311
if (CondOp) {
23042312
IsConditional = true;
23052313
addOperand(CondOp);
23062314
}
2315+
// The inloop reduction may across multiple scalar instruction and the
2316+
// underlying instruction may not contains the corresponding flags. Set the
2317+
// flags explicit from the redurrence descriptor.
2318+
setFastMathFlags(R.getFastMathFlags());
23072319
}
23082320

23092321
public:
23102322
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
23112323
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2312-
bool IsOrdered, DebugLoc DL = {})
2324+
bool IsOrdered)
23132325
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
23142326
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2315-
IsOrdered, DL) {}
2327+
IsOrdered) {}
23162328

23172329
~VPReductionRecipe() override = default;
23182330

23192331
VPReductionRecipe *clone() override {
23202332
return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
2321-
getVecOp(), getCondOp(), IsOrdered,
2322-
getDebugLoc());
2333+
getVecOp(), getCondOp(), IsOrdered);
23232334
}
23242335

23252336
static inline bool classof(const VPRecipeBase *R) {
@@ -2374,7 +2385,7 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
23742385
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
23752386
cast_or_null<Instruction>(R.getUnderlyingValue()),
23762387
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
2377-
R.isOrdered(), R.getDebugLoc()) {}
2388+
R.isOrdered()) {}
23782389

23792390
~VPReductionEVLRecipe() override = default;
23802391

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2226,7 +2226,8 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22262226
RecurKind Kind = RdxDesc.getRecurrenceKind();
22272227
// Propagate the fast-math flags carried by the underlying instruction.
22282228
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
2229-
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2229+
if (hasFastMathFlags())
2230+
State.Builder.setFastMathFlags(getFastMathFlags());
22302231
State.setDebugLocFrom(getDebugLoc());
22312232
Value *NewVecOp = State.get(getVecOp());
22322233
if (VPValue *Cond = getCondOp()) {
@@ -2277,7 +2278,8 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
22772278
// Propagate the fast-math flags carried by the underlying instruction.
22782279
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
22792280
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
2280-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2281+
if (hasFastMathFlags())
2282+
Builder.setFastMathFlags(getFastMathFlags());
22812283

22822284
RecurKind Kind = RdxDesc.getRecurrenceKind();
22832285
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
@@ -2314,6 +2316,8 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23142316
Type *ElementTy = Ctx.Types.inferScalarType(this);
23152317
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
23162318
unsigned Opcode = RdxDesc.getOpcode();
2319+
FastMathFlags FMFs =
2320+
hasFastMathFlags() ? getFastMathFlags() : FastMathFlags();
23172321

23182322
// TODO: Support any-of and in-loop reductions.
23192323
assert(
@@ -2333,12 +2337,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23332337
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
23342338
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
23352339
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2336-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2337-
Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2340+
return Cost +
2341+
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
23382342
}
23392343

2340-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2341-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2344+
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2345+
Ctx.CostKind);
23422346
}
23432347

23442348
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2349,8 +2353,8 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23492353
O << " = ";
23502354
getChainOp()->printAsOperand(O, SlotTracker);
23512355
O << " +";
2352-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2353-
O << getUnderlyingInstr()->getFastMathFlags();
2356+
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
2357+
printFlags(O);
23542358
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
23552359
getVecOp()->printAsOperand(O, SlotTracker);
23562360
if (isConditional()) {
@@ -2371,8 +2375,8 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
23712375
O << " = ";
23722376
getChainOp()->printAsOperand(O, SlotTracker);
23732377
O << " +";
2374-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2375-
O << getUnderlyingInstr()->getFastMathFlags();
2378+
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
2379+
printFlags(O);
23762380
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
23772381
getVecOp()->printAsOperand(O, SlotTracker);
23782382
O << ", ";

0 commit comments

Comments
 (0)