Skip to content

Commit 447cad5

Browse files
lukel97 and pawosm-arm
authored and committed
[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI (llvm#131300)
VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it down to take only the latter two. The motivation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.
1 parent 8dd520e commit 447cad5

File tree

5 files changed

+52
-56
lines changed

5 files changed

+52
-56
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9598,8 +9598,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
95989598
if (CM.blockNeedsPredicationForAnyReason(BB))
95999599
CondOp = RecipeBuilder.getBlockInMask(BB);
96009600

9601+
// Non-FP RdxDescs will have all fast math flags set, so clear them.
9602+
FastMathFlags FMFs = isa<FPMathOperator>(CurrentLinkI)
9603+
? RdxDesc.getFastMathFlags()
9604+
: FastMathFlags();
96019605
auto *RedRecipe = new VPReductionRecipe(
9602-
RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
9606+
Kind, FMFs, CurrentLinkI, PreviousLink, VecOp, CondOp,
96039607
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
96049608
// Append the recipe to the end of the VPBasicBlock because we need to
96059609
// ensure that it comes after all of it's inputs, including CondOp.

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2655,22 +2655,19 @@ class VPInterleaveRecipe : public VPRecipeBase {
26552655
/// a vector operand into a scalar value, and adding the result to a chain.
26562656
/// The Operands are {ChainOp, VecOp, [Condition]}.
26572657
class VPReductionRecipe : public VPRecipeWithIRFlags {
2658-
/// The recurrence decriptor for the reduction in question.
2659-
const RecurrenceDescriptor &RdxDesc;
2658+
/// The recurrence kind for the reduction in question.
2659+
RecurKind RdxKind;
26602660
bool IsOrdered;
26612661
/// Whether the reduction is conditional.
26622662
bool IsConditional = false;
26632663

26642664
protected:
2665-
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2666-
Instruction *I, ArrayRef<VPValue *> Operands,
2667-
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2668-
: VPRecipeWithIRFlags(SC, Operands,
2669-
isa_and_nonnull<FPMathOperator>(I)
2670-
? R.getFastMathFlags()
2671-
: FastMathFlags(),
2672-
DL),
2673-
RdxDesc(R), IsOrdered(IsOrdered) {
2665+
VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
2666+
FastMathFlags FMFs, Instruction *I,
2667+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2668+
bool IsOrdered, DebugLoc DL)
2669+
: VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
2670+
IsOrdered(IsOrdered) {
26742671
if (CondOp) {
26752672
IsConditional = true;
26762673
addOperand(CondOp);
@@ -2679,19 +2676,19 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
26792676
}
26802677

26812678
public:
2682-
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
2679+
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
26832680
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
26842681
bool IsOrdered, DebugLoc DL = {})
2685-
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
2682+
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, I,
26862683
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
26872684
IsOrdered, DL) {}
26882685

26892686
~VPReductionRecipe() override = default;
26902687

26912688
VPReductionRecipe *clone() override {
2692-
return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
2693-
getVecOp(), getCondOp(), IsOrdered,
2694-
getDebugLoc());
2689+
return new VPReductionRecipe(RdxKind, getFastMathFlags(),
2690+
getUnderlyingInstr(), getChainOp(), getVecOp(),
2691+
getCondOp(), IsOrdered, getDebugLoc());
26952692
}
26962693

26972694
static inline bool classof(const VPRecipeBase *R) {
@@ -2717,10 +2714,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
27172714
VPSlotTracker &SlotTracker) const override;
27182715
#endif
27192716

2720-
/// Return the recurrence decriptor for the in-loop reduction.
2721-
const RecurrenceDescriptor &getRecurrenceDescriptor() const {
2722-
return RdxDesc;
2723-
}
2717+
/// Return the recurrence kind for the in-loop reduction.
2718+
RecurKind getRecurrenceKind() const { return RdxKind; }
27242719
/// Return true if the in-loop reduction is ordered.
27252720
bool isOrdered() const { return IsOrdered; };
27262721
/// Return true if the in-loop reduction is conditional.
@@ -2744,7 +2739,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
27442739
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
27452740
DebugLoc DL = {})
27462741
: VPReductionRecipe(
2747-
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
2742+
VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
2743+
R.getFastMathFlags(),
27482744
cast_or_null<Instruction>(R.getUnderlyingValue()),
27492745
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
27502746
R.isOrdered(), DL) {}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,7 +2231,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
22312231
void VPReductionRecipe::execute(VPTransformState &State) {
22322232
assert(!State.Lane && "Reduction being replicated.");
22332233
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
2234-
RecurKind Kind = RdxDesc.getRecurrenceKind();
2234+
RecurKind Kind = getRecurrenceKind();
22352235
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
22362236
"In-loop AnyOf reductions aren't currently supported");
22372237
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2244,8 +2244,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22442244
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
22452245
Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
22462246

2247-
Value *Start =
2248-
getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
2247+
Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
22492248
if (State.VF.isVector())
22502249
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
22512250

@@ -2260,18 +2259,19 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22602259
createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
22612260
else
22622261
NewRed = State.Builder.CreateBinOp(
2263-
(Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
2262+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind),
2263+
PrevInChain, NewVecOp);
22642264
PrevInChain = NewRed;
22652265
NextInChain = NewRed;
22662266
} else {
22672267
PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
22682268
NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
22692269
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
2270-
NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
2271-
NewRed, PrevInChain);
2270+
NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
22722271
else
22732272
NextInChain = State.Builder.CreateBinOp(
2274-
(Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
2273+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed,
2274+
PrevInChain);
22752275
}
22762276
State.set(this, NextInChain, /*IsScalar*/ true);
22772277
}
@@ -2282,10 +2282,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
22822282
auto &Builder = State.Builder;
22832283
// Propagate the fast-math flags carried by the underlying instruction.
22842284
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
2285-
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
22862285
Builder.setFastMathFlags(getFastMathFlags());
22872286

2288-
RecurKind Kind = RdxDesc.getRecurrenceKind();
2287+
RecurKind Kind = getRecurrenceKind();
22892288
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
22902289
Value *VecOp = State.get(getVecOp());
22912290
Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2308,18 +2307,19 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23082307
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
23092308
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
23102309
else
2311-
NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
2312-
NewRed, Prev);
2310+
NewRed = Builder.CreateBinOp(
2311+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed,
2312+
Prev);
23132313
}
23142314
State.set(this, NewRed, /*IsScalar*/ true);
23152315
}
23162316

23172317
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23182318
VPCostContext &Ctx) const {
2319-
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
2319+
RecurKind RdxKind = getRecurrenceKind();
23202320
Type *ElementTy = Ctx.Types.inferScalarType(this);
23212321
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
2322-
unsigned Opcode = RdxDesc.getOpcode();
2322+
unsigned Opcode = RecurrenceDescriptor::getOpcode(RdxKind);
23232323
FastMathFlags FMFs = getFastMathFlags();
23242324

23252325
// TODO: Support any-of and in-loop reductions.
@@ -2332,9 +2332,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23322332
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
23332333
"In-loop reduction not implemented in VPlan-based cost model currently.");
23342334

2335-
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
2336-
"Inferred type and recurrence type mismatch.");
2337-
23382335
// Cost = Reduction cost + BinOp cost
23392336
InstructionCost Cost =
23402337
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
@@ -2357,28 +2354,30 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23572354
getChainOp()->printAsOperand(O, SlotTracker);
23582355
O << " +";
23592356
printFlags(O);
2360-
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2357+
O << " reduce."
2358+
<< Instruction::getOpcodeName(
2359+
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
2360+
<< " (";
23612361
getVecOp()->printAsOperand(O, SlotTracker);
23622362
if (isConditional()) {
23632363
O << ", ";
23642364
getCondOp()->printAsOperand(O, SlotTracker);
23652365
}
23662366
O << ")";
2367-
if (RdxDesc.IntermediateStore)
2368-
O << " (with final reduction value stored in invariant address sank "
2369-
"outside of loop)";
23702367
}
23712368

23722369
void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
23732370
VPSlotTracker &SlotTracker) const {
2374-
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
23752371
O << Indent << "REDUCE ";
23762372
printAsOperand(O, SlotTracker);
23772373
O << " = ";
23782374
getChainOp()->printAsOperand(O, SlotTracker);
23792375
O << " +";
23802376
printFlags(O);
2381-
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2377+
O << " vp.reduce."
2378+
<< Instruction::getOpcodeName(
2379+
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
2380+
<< " (";
23822381
getVecOp()->printAsOperand(O, SlotTracker);
23832382
O << ", ";
23842383
getEVL()->printAsOperand(O, SlotTracker);
@@ -2387,9 +2386,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
23872386
getCondOp()->printAsOperand(O, SlotTracker);
23882387
}
23892388
O << ")";
2390-
if (RdxDesc.IntermediateStore)
2391-
O << " (with final reduction value stored in invariant address sank "
2392-
"outside of loop)";
23932389
}
23942390
#endif
23952391

llvm/test/Transforms/LoopVectorize/vplan-printing.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
234234
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
235235
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
236236
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
237-
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
237+
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
238238
; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
239239
; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
240240
; CHECK-NEXT: No successors

llvm/unittests/Transforms/Vectorize/VPlanTest.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,8 +1125,8 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
11251125
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11261126
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11271127
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1128-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
1129-
VecOp, false);
1128+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1129+
CondOp, VecOp, false);
11301130
EXPECT_FALSE(Recipe.mayHaveSideEffects());
11311131
EXPECT_FALSE(Recipe.mayReadFromMemory());
11321132
EXPECT_FALSE(Recipe.mayWriteToMemory());
@@ -1140,8 +1140,8 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
11401140
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11411141
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11421142
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1143-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
1144-
VecOp, false);
1143+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1144+
CondOp, VecOp, false);
11451145
VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
11461146
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
11471147
EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
@@ -1495,8 +1495,8 @@ TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
14951495
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
14961496
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
14971497
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1498-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1499-
false);
1498+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1499+
CondOp, VecOp, false);
15001500
EXPECT_TRUE(isa<VPUser>(&Recipe));
15011501
VPRecipeBase *BaseR = &Recipe;
15021502
EXPECT_TRUE(isa<VPUser>(BaseR));
@@ -1510,8 +1510,8 @@ TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
15101510
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
15111511
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
15121512
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1513-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1514-
false);
1513+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1514+
CondOp, VecOp, false);
15151515
VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
15161516
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
15171517
EXPECT_TRUE(isa<VPUser>(&EVLRecipe));

0 commit comments

Comments
 (0)