@@ -1377,6 +1377,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns reduction bitwidth and signedness, if it does not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth >=
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
@@ -17916,24 +17928,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        TypeSize NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          APInt Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
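For context, a hedged sketch (not part of the patch) of the pattern this hunk detects and the rewrite described in its comment, with illustrative value names and n = 8:

```llvm
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i8 @llvm.ctpop.i8(i8)

; Before: the tree root is a ZExt from <8 x i1> and every value in
; UserIgnoreList is an integer add, so ReductionBitWidth is demoted to 1.
define i32 @before(<8 x i1> %mask) {
  %ext = zext <8 x i1> %mask to <8 x i32>
  %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %ext)
  ret i32 %sum
}

; After: the same reduction expressed as a population count of the mask
; bits, extended back to the original scalar type.
define i32 @after(<8 x i1> %mask) {
  %bits = bitcast <8 x i1> %mask to i8
  %pop = call i8 @llvm.ctpop.i8(i8 %bits)
  %sum = zext i8 %pop to i32
  ret i32 %sum
}
```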
@@ -19789,8 +19814,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+      InstructionCost ReductionCost = getReductionCost(
+          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -19895,10 +19920,12 @@ class HorizontalReduction {
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
             ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+                ReducedSubTree,
+                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
           }
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                         RdxRootInst->getType());
         }
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20079,12 +20106,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20143,8 +20171,22 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
             /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                     CostKind);
+        auto [Bitwidth, IsSigned] =
+            BitwidthAndSign.value_or(std::make_pair(0u, false));
+        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+          // Represent vector_reduce_add(ZExt(<n x i1>)) as
+          // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+          VectorCost =
+              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                    getWidenedType(ScalarTy, ReduxWidth),
+                                    TTI::CastContextHint::None, CostKind) +
+              TTI->getIntrinsicInstrCost(ICA, CostKind);
+        } else {
+          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                       FMF, CostKind);
+        }
       }
     }
     ScalarCost = EvaluateScalarCost([&]() {
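Under the same assumptions (illustrative n = 16, i32 reduction root), the new branch prices the demoted reduction as a scalar bitcast plus a ctpop rather than a generic add reduction; a sketch of the two operations being costed:

```llvm
%b = bitcast <16 x i1> %m to i16       ; the TTI->getCastInstrCost component
%p = call i16 @llvm.ctpop.i16(i16 %b)  ; the TTI->getIntrinsicInstrCost component
```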
@@ -20181,11 +20223,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
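Note that emitReduction itself returns the raw ctpop result of width n; the ZExtOrTrunc to DestTy is left to the existing type-adjustment code in the caller (see the ReducedSubTree hunk above). A hedged sketch with n = 4 and illustrative names:

```llvm
%v = bitcast <4 x i1> %cmp to i4
%c = call i4 @llvm.ctpop.i4(i4 %v)
; the caller then widens the result: %r = zext i4 %c to i32
```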