@@ -8799,12 +8799,10 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8799
8799
// / are valid so recipes can be formed later.
8800
8800
void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8801
8801
// Find all possible partial reductions.
8802
- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8802
+ SmallVector<std::pair<PartialReductionChain, unsigned >>
8803
8803
PartialReductionChains;
8804
8804
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8805
- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8806
- getScaledReduction (Phi, RdxDesc, Range))
8807
- PartialReductionChains.push_back (*Pair);
8805
+ PartialReductionChains.append (getScaledReduction (Phi, RdxDesc.getLoopExitInstr (), Range));
8808
8806
8809
8807
// A partial reduction is invalid if any of its extends are used by
8810
8808
// something that isn't another partial reduction. This is because the
@@ -8832,48 +8830,65 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8832
8830
}
8833
8831
}
8834
8832
8835
- std::optional <std::pair<PartialReductionChain, unsigned >>
8836
- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8837
- const RecurrenceDescriptor &Rdx ,
8833
+ SmallVector <std::pair<PartialReductionChain, unsigned >>
8834
+ VPRecipeBuilder::getScaledReduction (Instruction *PHI,
8835
+ Instruction *RdxExitInstr ,
8838
8836
VFRange &Range) {
8837
+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8838
+
8839
+ if (!CM.TheLoop ->contains (RdxExitInstr))
8840
+ return Chains;
8841
+
8839
8842
// TODO: Allow scaling reductions when predicating. The select at
8840
8843
// the end of the loop chooses between the phi value and most recent
8841
8844
// reduction result, both of which have different VFs to the active lane
8842
8845
// mask when scaling.
8843
- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8844
- return std::nullopt ;
8846
+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8847
+ return Chains ;
8845
8848
8846
- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8849
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8847
8850
if (!Update)
8848
- return std::nullopt ;
8851
+ return Chains ;
8849
8852
8850
8853
Value *Op = Update->getOperand (0 );
8851
8854
if (Op == PHI)
8852
8855
Op = Update->getOperand (1 );
8853
8856
8857
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8858
+ auto SR0 = getScaledReduction (PHI, OpInst, Range);
8859
+ if (!SR0.empty ()) {
8860
+ Chains.append (SR0);
8861
+ PHI = SR0.rbegin ()->first .Reduction ;
8862
+
8863
+ Op = Update->getOperand (0 );
8864
+ if (Op == PHI)
8865
+ Op = Update->getOperand (1 );
8866
+ }
8867
+ }
8868
+
8854
8869
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8855
8870
if (!BinOp || !BinOp->hasOneUse ())
8856
- return std::nullopt ;
8871
+ return Chains ;
8857
8872
8858
8873
using namespace llvm ::PatternMatch;
8859
8874
Value *A, *B;
8860
8875
if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8861
8876
!match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8862
- return std::nullopt ;
8877
+ return Chains ;
8863
8878
8864
8879
Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8865
8880
Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8866
8881
8867
8882
// Check that the extends extend from the same type.
8868
8883
if (A->getType () != B->getType ())
8869
- return std::nullopt ;
8884
+ return Chains ;
8870
8885
8871
8886
TTI::PartialReductionExtendKind OpAExtend =
8872
8887
TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8873
8888
TTI::PartialReductionExtendKind OpBExtend =
8874
8889
TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8875
8890
8876
- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8891
+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
8877
8892
8878
8893
unsigned TargetScaleFactor =
8879
8894
PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8887,9 +8902,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8887
8902
return Cost.isValid ();
8888
8903
},
8889
8904
Range))
8890
- return std::make_pair (Chain, TargetScaleFactor);
8905
+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
8891
8906
8892
- return std::nullopt ;
8907
+ return Chains ;
8893
8908
}
8894
8909
8895
8910
VPRecipeBase *
@@ -8986,7 +9001,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8986
9001
8987
9002
VPValue *BinOp = Operands[0 ];
8988
9003
VPValue *Phi = Operands[1 ];
8989
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9004
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
9005
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
8990
9006
std::swap (BinOp, Phi);
8991
9007
8992
9008
return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
0 commit comments