Skip to content

Commit 087a533

Browse files
committed
[VPlan] Change parent of VPReductionRecipe to VPRecipeWithIRFlags. NFC
This patch change the parent of the VPReductionRecipe from VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/control flags by the VPRecipeWithIRFlags. This will remove the dependency of the underlying instruction. This patch also add a new function `setFastMathFlags()` to the VPRecipeWithIRFlags because the entire reduction chain may contains multiple instructions. And the underlying instruction may not contains the corresponding flags for this reduction.
1 parent 9a0e4d7 commit 087a533

File tree

3 files changed

+36
-21
lines changed

3 files changed

+36
-21
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9799,9 +9799,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97999799
if (CM.blockNeedsPredicationForAnyReason(BB))
98009800
CondOp = RecipeBuilder.getBlockInMask(BB);
98019801

9802-
auto *RedRecipe = new VPReductionRecipe(
9803-
RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
9804-
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
9802+
auto *RedRecipe =
9803+
new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
9804+
CondOp, CM.useOrderedReductions(RdxDesc));
98059805
// Append the recipe to the end of the VPBasicBlock because we need to
98069806
// ensure that it comes after all of it's inputs, including CondOp.
98079807
// Delete CurrentLink as it will be invalid if its operand is replaced

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
713713
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
714714
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
715715
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
716+
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
717+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
716718
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
717719
R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
718720
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
@@ -788,6 +790,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
788790
}
789791
}
790792

793+
/// Set fast-math flags for this recipe.
794+
void setFastMathFlags(FastMathFlags FMFs) {
795+
OpType = OperationType::FPMathOp;
796+
this->FMFs = FMFs;
797+
}
798+
791799
CmpInst::Predicate getPredicate() const {
792800
assert(OpType == OperationType::Cmp &&
793801
"recipe doesn't have a compare predicate");
@@ -2286,7 +2294,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
22862294
/// A recipe to represent inloop reduction operations, performing a reduction on
22872295
/// a vector operand into a scalar value, and adding the result to a chain.
22882296
/// The Operands are {ChainOp, VecOp, [Condition]}.
2289-
class VPReductionRecipe : public VPSingleDefRecipe {
2297+
class VPReductionRecipe : public VPRecipeWithIRFlags {
22902298
/// The recurrence decriptor for the reduction in question.
22912299
const RecurrenceDescriptor &RdxDesc;
22922300
bool IsOrdered;
@@ -2296,29 +2304,32 @@ class VPReductionRecipe : public VPSingleDefRecipe {
22962304
protected:
22972305
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
22982306
Instruction *I, ArrayRef<VPValue *> Operands,
2299-
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2300-
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
2307+
VPValue *CondOp, bool IsOrdered)
2308+
: VPRecipeWithIRFlags(SC, Operands, *I), RdxDesc(R),
23012309
IsOrdered(IsOrdered) {
23022310
if (CondOp) {
23032311
IsConditional = true;
23042312
addOperand(CondOp);
23052313
}
2314+
// The inloop reduction may across multiple scalar instruction and the
2315+
// underlying instruction may not contains the corresponding flags. Set the
2316+
// flags explicit from the redurrence descriptor.
2317+
setFastMathFlags(R.getFastMathFlags());
23062318
}
23072319

23082320
public:
23092321
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
23102322
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2311-
bool IsOrdered, DebugLoc DL = {})
2323+
bool IsOrdered)
23122324
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
23132325
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2314-
IsOrdered, DL) {}
2326+
IsOrdered) {}
23152327

23162328
~VPReductionRecipe() override = default;
23172329

23182330
VPReductionRecipe *clone() override {
23192331
return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
2320-
getVecOp(), getCondOp(), IsOrdered,
2321-
getDebugLoc());
2332+
getVecOp(), getCondOp(), IsOrdered);
23222333
}
23232334

23242335
static inline bool classof(const VPRecipeBase *R) {
@@ -2373,7 +2384,7 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
23732384
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
23742385
cast_or_null<Instruction>(R.getUnderlyingValue()),
23752386
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
2376-
R.isOrdered(), R.getDebugLoc()) {}
2387+
R.isOrdered()) {}
23772388

23782389
~VPReductionEVLRecipe() override = default;
23792390

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,7 +2224,8 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22242224
RecurKind Kind = RdxDesc.getRecurrenceKind();
22252225
// Propagate the fast-math flags carried by the underlying instruction.
22262226
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
2227-
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2227+
if (hasFastMathFlags())
2228+
State.Builder.setFastMathFlags(getFastMathFlags());
22282229
State.setDebugLocFrom(getDebugLoc());
22292230
Value *NewVecOp = State.get(getVecOp());
22302231
if (VPValue *Cond = getCondOp()) {
@@ -2275,7 +2276,8 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
22752276
// Propagate the fast-math flags carried by the underlying instruction.
22762277
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
22772278
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
2278-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2279+
if (hasFastMathFlags())
2280+
Builder.setFastMathFlags(getFastMathFlags());
22792281

22802282
RecurKind Kind = RdxDesc.getRecurrenceKind();
22812283
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
@@ -2312,6 +2314,8 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23122314
Type *ElementTy = Ctx.Types.inferScalarType(this);
23132315
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
23142316
unsigned Opcode = RdxDesc.getOpcode();
2317+
FastMathFlags FMFs =
2318+
hasFastMathFlags() ? getFastMathFlags() : FastMathFlags();
23152319

23162320
// TODO: Support any-of and in-loop reductions.
23172321
assert(
@@ -2331,12 +2335,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23312335
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
23322336
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
23332337
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2334-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2335-
Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2338+
return Cost +
2339+
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
23362340
}
23372341

2338-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2339-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2342+
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2343+
Ctx.CostKind);
23402344
}
23412345

23422346
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2347,8 +2351,8 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23472351
O << " = ";
23482352
getChainOp()->printAsOperand(O, SlotTracker);
23492353
O << " +";
2350-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2351-
O << getUnderlyingInstr()->getFastMathFlags();
2354+
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
2355+
printFlags(O);
23522356
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
23532357
getVecOp()->printAsOperand(O, SlotTracker);
23542358
if (isConditional()) {
@@ -2369,8 +2373,8 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
23692373
O << " = ";
23702374
getChainOp()->printAsOperand(O, SlotTracker);
23712375
O << " +";
2372-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2373-
O << getUnderlyingInstr()->getFastMathFlags();
2376+
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
2377+
printFlags(O);
23742378
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
23752379
getVecOp()->printAsOperand(O, SlotTracker);
23762380
O << ", ";

0 commit comments

Comments
 (0)