@@ -8790,12 +8790,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8790
8790
// / are valid so recipes can be formed later.
8791
8791
void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8792
8792
// Find all possible partial reductions.
8793
- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8793
+ SmallVector<std::pair<PartialReductionChain, unsigned >>
8794
8794
PartialReductionChains;
8795
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8796
- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8797
- getScaledReduction (Phi, RdxDesc, Range))
8798
- PartialReductionChains. push_back (*Pair);
8795
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8796
+ if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8797
+ PartialReductionChains. append (*SR);
8798
+ }
8799
8799
8800
8800
// A partial reduction is invalid if any of its extends are used by
8801
8801
// something that isn't another partial reduction. This is because the
@@ -8823,26 +8823,42 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8823
8823
}
8824
8824
}
8825
8825
8826
- std::optional<std::pair<PartialReductionChain, unsigned >>
8827
- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8828
- const RecurrenceDescriptor &Rdx ,
8826
+ std::optional<SmallVector< std::pair<PartialReductionChain, unsigned > >>
8827
+ VPRecipeBuilder::getScaledReduction (Instruction *PHI,
8828
+ Instruction *RdxExitInstr ,
8829
8829
VFRange &Range) {
8830
+
8831
+ if (!CM.TheLoop ->contains (RdxExitInstr))
8832
+ return std::nullopt;
8833
+
8830
8834
// TODO: Allow scaling reductions when predicating. The select at
8831
8835
// the end of the loop chooses between the phi value and most recent
8832
8836
// reduction result, both of which have different VFs to the active lane
8833
8837
// mask when scaling.
8834
- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8838
+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8835
8839
return std::nullopt;
8836
8840
8837
- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8841
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8838
8842
if (!Update)
8839
8843
return std::nullopt;
8840
8844
8841
8845
Value *Op = Update->getOperand (0 );
8842
8846
Value *PhiOp = Update->getOperand (1 );
8843
- if (Op == PHI) {
8844
- Op = Update->getOperand (1 );
8845
- PhiOp = Update->getOperand (0 );
8847
+ if (Op == PHI)
8848
+ std::swap (Op, PhiOp);
8849
+
8850
+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8851
+
8852
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8853
+ if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8854
+ Chains.append (*SR0);
8855
+ PHI = SR0->rbegin ()->first .Reduction ;
8856
+
8857
+ Op = Update->getOperand (0 );
8858
+ PhiOp = Update->getOperand (1 );
8859
+ if (Op == PHI)
8860
+ std::swap (Op, PhiOp);
8861
+ }
8846
8862
}
8847
8863
if (PhiOp != PHI)
8848
8864
return std::nullopt;
@@ -8860,12 +8876,16 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8860
8876
Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8861
8877
Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8862
8878
8879
+ // Check that the extends extend from the same type.
8880
+ if (A->getType () != B->getType ())
8881
+ return std::nullopt;
8882
+
8863
8883
TTI::PartialReductionExtendKind OpAExtend =
8864
8884
TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8865
8885
TTI::PartialReductionExtendKind OpBExtend =
8866
8886
TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8867
8887
8868
- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8888
+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
8869
8889
8870
8890
unsigned TargetScaleFactor =
8871
8891
PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8880,9 +8900,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8880
8900
return Cost.isValid ();
8881
8901
},
8882
8902
Range))
8883
- return std::make_pair (Chain, TargetScaleFactor);
8903
+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
8884
8904
8885
- return std::nullopt ;
8905
+ return Chains ;
8886
8906
}
8887
8907
8888
8908
VPRecipeBase *
@@ -8979,7 +8999,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8979
8999
8980
9000
VPValue *BinOp = Operands[0 ];
8981
9001
VPValue *Phi = Operands[1 ];
8982
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9002
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
9003
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
8983
9004
std::swap (BinOp, Phi);
8984
9005
8985
9006
return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
0 commit comments