@@ -22144,53 +22144,16 @@ class HorizontalReduction {
22144
22144
}
22145
22145
22146
22146
Type *ScalarTy = VL.front()->getType();
22147
- if (isa<FixedVectorType>(ScalarTy)) {
22148
- assert(SLPReVec && "FixedVectorType is not expected.");
22149
- unsigned ScalarTyNumElements = getNumElements(ScalarTy);
22150
- Value *ReducedSubTree = PoisonValue::get(
22151
- getWidenedType(ScalarTy->getScalarType(), ScalarTyNumElements));
22152
- for (unsigned I : seq<unsigned>(ScalarTyNumElements)) {
22153
- // Do reduction for each lane.
22154
- // e.g., do reduce add for
22155
- // VL[0] = <4 x Ty> <a, b, c, d>
22156
- // VL[1] = <4 x Ty> <e, f, g, h>
22157
- // Lane[0] = <2 x Ty> <a, e>
22158
- // Lane[1] = <2 x Ty> <b, f>
22159
- // Lane[2] = <2 x Ty> <c, g>
22160
- // Lane[3] = <2 x Ty> <d, h>
22161
- // result[0] = reduce add Lane[0]
22162
- // result[1] = reduce add Lane[1]
22163
- // result[2] = reduce add Lane[2]
22164
- // result[3] = reduce add Lane[3]
22165
- SmallVector<int, 16> Mask =
22166
- createStrideMask(I, ScalarTyNumElements, VL.size());
22167
- Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
22168
- Value *Val =
22169
- createSingleOp(Builder, *TTI, Lane,
22170
- OptReusedScalars && SameScaleFactor
22171
- ? SameValuesCounter.front().second
22172
- : 1,
22173
- Lane->getType()->getScalarType() !=
22174
- VL.front()->getType()->getScalarType()
22175
- ? V.isSignedMinBitwidthRootNode()
22176
- : true,
22177
- RdxRootInst->getType());
22178
- ReducedSubTree =
22179
- Builder.CreateInsertElement(ReducedSubTree, Val, I);
22180
- }
22181
- VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree);
22182
- } else {
22183
- Type *VecTy = VectorizedRoot->getType();
22184
- Type *RedScalarTy = VecTy->getScalarType();
22185
- VectorValuesAndScales.emplace_back(
22186
- VectorizedRoot,
22187
- OptReusedScalars && SameScaleFactor
22188
- ? SameValuesCounter.front().second
22189
- : 1,
22190
- RedScalarTy != ScalarTy->getScalarType()
22191
- ? V.isSignedMinBitwidthRootNode()
22192
- : true);
22193
- }
22147
+ Type *VecTy = VectorizedRoot->getType();
22148
+ Type *RedScalarTy = VecTy->getScalarType();
22149
+ VectorValuesAndScales.emplace_back(
22150
+ VectorizedRoot,
22151
+ OptReusedScalars && SameScaleFactor
22152
+ ? SameValuesCounter.front().second
22153
+ : 1,
22154
+ RedScalarTy != ScalarTy->getScalarType()
22155
+ ? V.isSignedMinBitwidthRootNode()
22156
+ : true);
22194
22157
22195
22158
// Count vectorized reduced values to exclude them from final reduction.
22196
22159
for (Value *RdxVal : VL) {
@@ -22363,9 +22326,35 @@ class HorizontalReduction {
22363
22326
Value *createSingleOp(IRBuilderBase &Builder, const TargetTransformInfo &TTI,
22364
22327
Value *Vec, unsigned Scale, bool IsSigned,
22365
22328
Type *DestTy) {
22366
- Value *Rdx = emitReduction(Vec, Builder, &TTI, DestTy);
22367
- if (Rdx->getType() != DestTy->getScalarType())
22368
- Rdx = Builder.CreateIntCast(Rdx, DestTy->getScalarType(), IsSigned);
22329
+ Value *Rdx;
22330
+ if (auto *VecTy = dyn_cast<FixedVectorType>(DestTy)) {
22331
+ unsigned DestTyNumElements = getNumElements(VecTy);
22332
+ unsigned VF = getNumElements(Vec->getType()) / DestTyNumElements;
22333
+ Rdx = PoisonValue::get(
22334
+ getWidenedType(Vec->getType()->getScalarType(), DestTyNumElements));
22335
+ for (unsigned I : seq<unsigned>(DestTyNumElements)) {
22336
+ // Do reduction for each lane.
22337
+ // e.g., do reduce add for
22338
+ // VL[0] = <4 x Ty> <a, b, c, d>
22339
+ // VL[1] = <4 x Ty> <e, f, g, h>
22340
+ // Lane[0] = <2 x Ty> <a, e>
22341
+ // Lane[1] = <2 x Ty> <b, f>
22342
+ // Lane[2] = <2 x Ty> <c, g>
22343
+ // Lane[3] = <2 x Ty> <d, h>
22344
+ // result[0] = reduce add Lane[0]
22345
+ // result[1] = reduce add Lane[1]
22346
+ // result[2] = reduce add Lane[2]
22347
+ // result[3] = reduce add Lane[3]
22348
+ SmallVector<int, 16> Mask = createStrideMask(I, DestTyNumElements, VF);
22349
+ Value *Lane = Builder.CreateShuffleVector(Vec, Mask);
22350
+ Rdx = Builder.CreateInsertElement(
22351
+ Rdx, emitReduction(Lane, Builder, &TTI, DestTy), I);
22352
+ }
22353
+ } else {
22354
+ Rdx = emitReduction(Vec, Builder, &TTI, DestTy);
22355
+ }
22356
+ if (Rdx->getType() != DestTy)
22357
+ Rdx = Builder.CreateIntCast(Rdx, DestTy, IsSigned);
22369
22358
// Improved analysis for add/fadd/xor reductions with same scale
22370
22359
// factor for all operands of reductions. We can emit scalar ops for
22371
22360
// them instead.
@@ -22432,30 +22421,32 @@ class HorizontalReduction {
22432
22421
case RecurKind::FMul: {
22433
22422
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
22434
22423
if (!AllConsts) {
22435
- if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
22436
- assert(SLPReVec && "FixedVectorType is not expected.");
22437
- unsigned ScalarTyNumElements = VecTy->getNumElements();
22438
- for (unsigned I : seq<unsigned>(ReducedVals.size())) {
22439
- VectorCost += TTI->getShuffleCost(
22440
- TTI::SK_PermuteSingleSrc, VectorTy,
22441
- createStrideMask(I, ScalarTyNumElements, ReducedVals.size()));
22442
- VectorCost += TTI->getArithmeticReductionCost(RdxOpcode, VecTy, FMF,
22443
- CostKind);
22444
- }
22445
- VectorCost += TTI->getScalarizationOverhead(
22446
- VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
22447
- /*Extract*/ false, TTI::TCK_RecipThroughput);
22448
- } else if (DoesRequireReductionOp) {
22449
- Type *RedTy = VectorTy->getElementType();
22450
- auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
22451
- std::make_pair(RedTy, true));
22452
- if (RType == RedTy) {
22453
- VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
22454
- FMF, CostKind);
22424
+ if (DoesRequireReductionOp) {
22425
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
22426
+ assert(SLPReVec && "FixedVectorType is not expected.");
22427
+ unsigned ScalarTyNumElements = VecTy->getNumElements();
22428
+ for (unsigned I : seq<unsigned>(ReducedVals.size())) {
22429
+ VectorCost += TTI->getShuffleCost(
22430
+ TTI::SK_PermuteSingleSrc, VectorTy,
22431
+ createStrideMask(I, ScalarTyNumElements, ReducedVals.size()));
22432
+ VectorCost += TTI->getArithmeticReductionCost(RdxOpcode, VecTy,
22433
+ FMF, CostKind);
22434
+ }
22435
+ VectorCost += TTI->getScalarizationOverhead(
22436
+ VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
22437
+ /*Extract*/ false, TTI::TCK_RecipThroughput);
22455
22438
} else {
22456
- VectorCost = TTI->getExtendedReductionCost(
22457
- RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
22458
- FMF, CostKind);
22439
+ Type *RedTy = VectorTy->getElementType();
22440
+ auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
22441
+ std::make_pair(RedTy, true));
22442
+ if (RType == RedTy) {
22443
+ VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
22444
+ FMF, CostKind);
22445
+ } else {
22446
+ VectorCost = TTI->getExtendedReductionCost(
22447
+ RdxOpcode, !IsSigned, RedTy,
22448
+ getWidenedType(RType, ReduxWidth), FMF, CostKind);
22449
+ }
22459
22450
}
22460
22451
} else {
22461
22452
Type *RedTy = VectorTy->getElementType();
0 commit comments