@@ -8268,6 +8268,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8268
8268
return Recipe;
8269
8269
}
8270
8270
8271
+ // / Find all possible partial reductions in the loop and track all of those that
8272
+ // / are valid so recipes can be formed later.
8273
+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8274
+ // Find all possible partial reductions.
8275
+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8276
+ PartialReductionChains;
8277
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8278
+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8279
+ getScaledReduction (Phi, RdxDesc, Range))
8280
+ PartialReductionChains.push_back (*Pair);
8281
+
8282
+ // A partial reduction is invalid if any of its extends are used by
8283
+ // something that isn't another partial reduction. This is because the
8284
+ // extends are intended to be lowered along with the reduction itself.
8285
+
8286
+ // Build up a set of partial reduction bin ops for efficient use checking.
8287
+ SmallSet<User *, 4 > PartialReductionBinOps;
8288
+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8289
+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8290
+
8291
+ auto ExtendIsOnlyUsedByPartialReductions =
8292
+ [&PartialReductionBinOps](Instruction *Extend) {
8293
+ return all_of (Extend->users (), [&](const User *U) {
8294
+ return PartialReductionBinOps.contains (U);
8295
+ });
8296
+ };
8297
+
8298
+ // Check if each use of a chain's two extends is a partial reduction
8299
+ // and only add those that don't have non-partial reduction users.
8300
+ for (auto Pair : PartialReductionChains) {
8301
+ PartialReductionChain Chain = Pair.first ;
8302
+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8303
+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8304
+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8305
+ }
8306
+ }
8307
+
8308
+ std::optional<std::pair<PartialReductionChain, unsigned >>
8309
+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8310
+ const RecurrenceDescriptor &Rdx,
8311
+ VFRange &Range) {
8312
+ // TODO: Allow scaling reductions when predicating. The select at
8313
+ // the end of the loop chooses between the phi value and most recent
8314
+ // reduction result, both of which have different VFs to the active lane
8315
+ // mask when scaling.
8316
+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8317
+ return std::nullopt;
8318
+
8319
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8320
+ if (!Update)
8321
+ return std::nullopt;
8322
+
8323
+ Value *Op = Update->getOperand (0 );
8324
+ Value *PhiOp = Update->getOperand (1 );
8325
+ if (Op == PHI) {
8326
+ Op = Update->getOperand (1 );
8327
+ PhiOp = Update->getOperand (0 );
8328
+ }
8329
+ if (PhiOp != PHI)
8330
+ return std::nullopt;
8331
+
8332
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8333
+ if (!BinOp || !BinOp->hasOneUse ())
8334
+ return std::nullopt;
8335
+
8336
+ using namespace llvm ::PatternMatch;
8337
+ Value *A, *B;
8338
+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8339
+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8340
+ return std::nullopt;
8341
+
8342
+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8343
+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8344
+
8345
+ TTI::PartialReductionExtendKind OpAExtend =
8346
+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8347
+ TTI::PartialReductionExtendKind OpBExtend =
8348
+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8349
+
8350
+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8351
+
8352
+ unsigned TargetScaleFactor =
8353
+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8354
+ A->getType ()->getPrimitiveSizeInBits ());
8355
+
8356
+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8357
+ [&](ElementCount VF) {
8358
+ InstructionCost Cost = TTI->getPartialReductionCost (
8359
+ Update->getOpcode (), A->getType (), B->getType (), PHI->getType (),
8360
+ VF, OpAExtend, OpBExtend,
8361
+ std::make_optional (BinOp->getOpcode ()));
8362
+ return Cost.isValid ();
8363
+ },
8364
+ Range))
8365
+ return std::make_pair (Chain, TargetScaleFactor);
8366
+
8367
+ return std::nullopt;
8368
+ }
8369
+
8271
8370
VPRecipeBase *
8272
8371
VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
8273
8372
ArrayRef<VPValue *> Operands,
@@ -8292,9 +8391,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8292
8391
Legal->getReductionVars ().find (Phi)->second ;
8293
8392
assert (RdxDesc.getRecurrenceStartValue () ==
8294
8393
Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8295
- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8296
- CM.isInLoopReduction (Phi),
8297
- CM.useOrderedReductions (RdxDesc));
8394
+
8395
+ // If the PHI is used by a partial reduction, set the scale factor.
8396
+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8397
+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8398
+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8399
+ PhiRecipe = new VPReductionPHIRecipe (
8400
+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8401
+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
8298
8402
} else {
8299
8403
// TODO: Currently fixed-order recurrences are modeled as chains of
8300
8404
// first-order recurrences. If there are no users of the intermediate
@@ -8322,6 +8426,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8322
8426
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
8323
8427
return tryToWidenMemory (Instr, Operands, Range);
8324
8428
8429
+ if (getScaledReductionForInstr (Instr))
8430
+ return tryToCreatePartialReduction (Instr, Operands);
8431
+
8325
8432
if (!shouldWiden (Instr, Range))
8326
8433
return nullptr ;
8327
8434
@@ -8342,6 +8449,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8342
8449
return tryToWiden (Instr, Operands, VPBB);
8343
8450
}
8344
8451
8452
+ VPRecipeBase *
8453
+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
8454
+ ArrayRef<VPValue *> Operands) {
8455
+ assert (Operands.size () == 2 &&
8456
+ " Unexpected number of operands for partial reduction" );
8457
+
8458
+ VPValue *BinOp = Operands[0 ];
8459
+ VPValue *Phi = Operands[1 ];
8460
+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8461
+ std::swap (BinOp, Phi);
8462
+
8463
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8464
+ Reduction);
8465
+ }
8466
+
8345
8467
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
8346
8468
ElementCount MaxVF) {
8347
8469
assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -8514,7 +8636,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
8514
8636
bool HasNUW = Style == TailFoldingStyle::None;
8515
8637
addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
8516
8638
8517
- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
8639
+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
8640
+ Builder);
8518
8641
8519
8642
// ---------------------------------------------------------------------------
8520
8643
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -8560,6 +8683,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
8560
8683
bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
8561
8684
return Legal->blockNeedsPredication (BB) || NeedsBlends;
8562
8685
});
8686
+
8687
+ RecipeBuilder.collectScaledReductions (Range);
8688
+
8563
8689
for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
8564
8690
// Relevant instructions from basic block BB will be grouped into VPRecipe
8565
8691
// ingredients and fill a new VPBasicBlock.
@@ -8770,7 +8896,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
8770
8896
bool HasNUW = true ;
8771
8897
addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW,
8772
8898
DebugLoc ());
8773
- assert (verifyVPlanIsValid (*Plan) && " VPlan is invalid" );
8899
+ assert (verifyVPlanIsValid (*Plan) && " VPlan is invalid" );
8774
8900
return Plan;
8775
8901
}
8776
8902
0 commit comments