@@ -7532,6 +7532,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
7532
7532
}
7533
7533
continue ;
7534
7534
}
7535
+ // The VPlan-based cost model is more accurate for partial reduction and
7536
+ // comparing against the legacy cost isn't desirable.
7537
+ if (isa<VPPartialReductionRecipe>(&R))
7538
+ return true ;
7535
7539
if (Instruction *UI = GetInstructionForCost (&R))
7536
7540
SeenInstrs.insert (UI);
7537
7541
}
@@ -8746,6 +8750,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8746
8750
return Recipe;
8747
8751
}
8748
8752
8753
+ // / Find all possible partial reductions in the loop and track all of those that
8754
+ // / are valid so recipes can be formed later.
8755
+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8756
+ // Find all possible partial reductions.
8757
+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8758
+ PartialReductionChains;
8759
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8760
+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8761
+ getScaledReduction (Phi, RdxDesc, Range))
8762
+ PartialReductionChains.push_back (*Pair);
8763
+
8764
+ // A partial reduction is invalid if any of its extends are used by
8765
+ // something that isn't another partial reduction. This is because the
8766
+ // extends are intended to be lowered along with the reduction itself.
8767
+
8768
+ // Build up a set of partial reduction bin ops for efficient use checking.
8769
+ SmallSet<User *, 4 > PartialReductionBinOps;
8770
+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8771
+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8772
+
8773
+ auto ExtendIsOnlyUsedByPartialReductions =
8774
+ [&PartialReductionBinOps](Instruction *Extend) {
8775
+ return all_of (Extend->users (), [&](const User *U) {
8776
+ return PartialReductionBinOps.contains (U);
8777
+ });
8778
+ };
8779
+
8780
+ // Check if each use of a chain's two extends is a partial reduction
8781
+ // and only add those that don't have non-partial reduction users.
8782
+ for (auto Pair : PartialReductionChains) {
8783
+ PartialReductionChain Chain = Pair.first ;
8784
+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8785
+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8786
+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8787
+ }
8788
+ }
8789
+
8790
+ std::optional<std::pair<PartialReductionChain, unsigned >>
8791
+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8792
+ const RecurrenceDescriptor &Rdx,
8793
+ VFRange &Range) {
8794
+ // TODO: Allow scaling reductions when predicating. The select at
8795
+ // the end of the loop chooses between the phi value and most recent
8796
+ // reduction result, both of which have different VFs to the active lane
8797
+ // mask when scaling.
8798
+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8799
+ return std::nullopt;
8800
+
8801
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8802
+ if (!Update)
8803
+ return std::nullopt;
8804
+
8805
+ Value *Op = Update->getOperand (0 );
8806
+ if (Op == PHI)
8807
+ Op = Update->getOperand (1 );
8808
+
8809
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8810
+ if (!BinOp || !BinOp->hasOneUse ())
8811
+ return std::nullopt;
8812
+
8813
+ using namespace llvm ::PatternMatch;
8814
+ Value *A, *B;
8815
+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8816
+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8817
+ return std::nullopt;
8818
+
8819
+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8820
+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8821
+
8822
+ // Check that the extends extend from the same type.
8823
+ if (A->getType () != B->getType ())
8824
+ return std::nullopt;
8825
+
8826
+ TTI::PartialReductionExtendKind OpAExtend =
8827
+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8828
+ TTI::PartialReductionExtendKind OpBExtend =
8829
+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8830
+
8831
+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8832
+
8833
+ unsigned TargetScaleFactor =
8834
+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8835
+ A->getType ()->getPrimitiveSizeInBits ());
8836
+
8837
+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8838
+ [&](ElementCount VF) {
8839
+ InstructionCost Cost = TTI->getPartialReductionCost (
8840
+ Update->getOpcode (), A->getType (), PHI->getType (), VF,
8841
+ OpAExtend, OpBExtend, std::make_optional (BinOp->getOpcode ()));
8842
+ return Cost.isValid ();
8843
+ },
8844
+ Range))
8845
+ return std::make_pair (Chain, TargetScaleFactor);
8846
+
8847
+ return std::nullopt;
8848
+ }
8849
+
8749
8850
VPRecipeBase *
8750
8851
VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
8751
8852
ArrayRef<VPValue *> Operands,
@@ -8770,9 +8871,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8770
8871
Legal->getReductionVars ().find (Phi)->second ;
8771
8872
assert (RdxDesc.getRecurrenceStartValue () ==
8772
8873
Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8773
- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8774
- CM.isInLoopReduction (Phi),
8775
- CM.useOrderedReductions (RdxDesc));
8874
+
8875
+ // If the PHI is used by a partial reduction, set the scale factor.
8876
+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8877
+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8878
+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8879
+ PhiRecipe = new VPReductionPHIRecipe (
8880
+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8881
+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
8776
8882
} else {
8777
8883
// TODO: Currently fixed-order recurrences are modeled as chains of
8778
8884
// first-order recurrences. If there are no users of the intermediate
@@ -8804,6 +8910,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8804
8910
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
8805
8911
return tryToWidenMemory (Instr, Operands, Range);
8806
8912
8913
+ if (getScaledReductionForInstr (Instr))
8914
+ return tryToCreatePartialReduction (Instr, Operands);
8915
+
8807
8916
if (!shouldWiden (Instr, Range))
8808
8917
return nullptr ;
8809
8918
@@ -8824,6 +8933,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8824
8933
return tryToWiden (Instr, Operands, VPBB);
8825
8934
}
8826
8935
8936
+ VPRecipeBase *
8937
+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
8938
+ ArrayRef<VPValue *> Operands) {
8939
+ assert (Operands.size () == 2 &&
8940
+ " Unexpected number of operands for partial reduction" );
8941
+
8942
+ VPValue *BinOp = Operands[0 ];
8943
+ VPValue *Phi = Operands[1 ];
8944
+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8945
+ std::swap (BinOp, Phi);
8946
+
8947
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8948
+ Reduction);
8949
+ }
8950
+
8827
8951
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
8828
8952
ElementCount MaxVF) {
8829
8953
assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9247,7 +9371,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9247
9371
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
9248
9372
addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
9249
9373
9250
- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9374
+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9375
+ Builder);
9251
9376
9252
9377
// ---------------------------------------------------------------------------
9253
9378
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9293,6 +9418,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9293
9418
bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
9294
9419
return Legal->blockNeedsPredication (BB) || NeedsBlends;
9295
9420
});
9421
+
9422
+ RecipeBuilder.collectScaledReductions (Range);
9423
+
9296
9424
auto *MiddleVPBB = Plan->getMiddleBlock ();
9297
9425
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
9298
9426
for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
0 commit comments