@@ -8758,13 +8758,6 @@ bool VPRecipeBuilder::getScaledReductions(
8758
8758
if (!CM.TheLoop ->contains (RdxExitInstr))
8759
8759
return false ;
8760
8760
8761
- // TODO: Allow scaling reductions when predicating. The select at
8762
- // the end of the loop chooses between the phi value and most recent
8763
- // reduction result, both of which have different VFs to the active lane
8764
- // mask when scaling.
8765
- if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr->getParent ()))
8766
- return false ;
8767
-
8768
8761
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
8769
8762
if (!Update)
8770
8763
return false ;
@@ -8926,8 +8919,19 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8926
8919
isa<VPPartialReductionRecipe>(BinOpRecipe))
8927
8920
std::swap (BinOp, Accumulator);
8928
8921
8929
- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8930
- Accumulator, Reduction);
8922
+ unsigned ReductionOpcode = Reduction->getOpcode ();
8923
+ if (CM.blockNeedsPredicationForAnyReason (Reduction->getParent ())) {
8924
+ assert ((ReductionOpcode == Instruction::Add ||
8925
+ ReductionOpcode == Instruction::Sub) &&
8926
+ " Expected an ADD or SUB operation for predicated partial "
8927
+ " reductions (because the neutral element in the mask is zero)!" );
8928
+ VPValue *Mask = getBlockInMask (Reduction->getParent ());
8929
+ VPValue *Zero =
8930
+ Plan.getOrAddLiveIn (ConstantInt::get (Reduction->getType (), 0 ));
8931
+ BinOp = Builder.createSelect (Mask, BinOp, Zero, Reduction->getDebugLoc ());
8932
+ }
8933
+ return new VPPartialReductionRecipe (ReductionOpcode, BinOp, Accumulator,
8934
+ Reduction);
8931
8935
}
8932
8936
8933
8937
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
@@ -9735,7 +9739,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9735
9739
// beginning of the dedicated latch block.
9736
9740
auto *OrigExitingVPV = PhiR->getBackedgeValue ();
9737
9741
auto *NewExitingVPV = PhiR->getBackedgeValue ();
9738
- if (!PhiR->isInLoop () && CM.foldTailByMasking ()) {
9742
+ // Don't output selects for partial reductions because they have an output
9743
+ // with fewer lanes than the VF. So the operands of the select would have
9744
+ // different numbers of lanes. Partial reductions mask the input instead.
9745
+ if (!PhiR->isInLoop () && CM.foldTailByMasking () &&
9746
+ !isa<VPPartialReductionRecipe>(OrigExitingVPV->getDefiningRecipe ())) {
9739
9747
VPValue *Cond = RecipeBuilder.getBlockInMask (OrigLoop->getHeader ());
9740
9748
assert (OrigExitingVPV->getDefiningRecipe ()->getParent () != LatchVPBB &&
9741
9749
" reduction recipe must be defined before latch" );
0 commit comments