Commit 14bdcef

[SLP]Model reduction_add(ext(<n x i1>)) as ext(ctpop(bitcast <n x i1> to int n))
Currently, sequences of the form reduction_add(ext(<n x i1>)) are modeled as a vector extension plus a reduction add, but InstCombine later transforms them into ext(ctpop(bitcast <n x i1> to int n)). This patch adds direct support for that form to the SLP vectorizer, which enables better cost estimation.

AVX512, -O3+LTO:
CINT2006/445.gobmk - extra vector code
Prolangs-C/bison - extra vector code
Benchmarks/NPB-serial/is - 16 x + 8 x reductions vectorized as a single 24 x reduction

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: #116875
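To illustrate the new modeling, here is a minimal before/after IR sketch of the pattern the patch recognizes (value names are illustrative; the exact IR produced is shown in the updated tests below):

  ; Before: modeled as a widening extension followed by a reduction add.
  %ext = zext <8 x i1> %mask to <8 x i16>
  %red = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %ext)

  ; After: modeled (and emitted) as bitcast + ctpop + extend to the destination type.
  %bits = bitcast <8 x i1> %mask to i8
  %pop = call i8 @llvm.ctpop.i8(i8 %bits)
  %red = zext i8 %pop to i16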
1 parent c7d5ef4

3 files changed: +86 -31 lines

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 80 additions & 27 deletions
@@ -1377,6 +1377,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns reduction bitwidth and signedness, if it does not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth >=
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
@@ -17916,24 +17928,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to ZExtOrTrunc(ctpop(bitcast <n
+    // x i1> to in)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        TypeSize NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          APInt Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
@@ -19789,8 +19814,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+      InstructionCost ReductionCost = getReductionCost(
+          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -19895,10 +19920,12 @@ class HorizontalReduction {
               createStrideMask(I, ScalarTyNumElements, VL.size());
           Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
           ReducedSubTree = Builder.CreateInsertElement(
-              ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+              ReducedSubTree,
+              emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
         }
       } else {
-        ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+        ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                       RdxRootInst->getType());
       }
       if (ReducedSubTree->getType() != VL.front()->getType()) {
         assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20079,12 +20106,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20143,8 +20171,22 @@ class HorizontalReduction {
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
         } else {
-          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                       CostKind);
+          auto [Bitwidth, IsSigned] =
+              BitwidthAndSign.value_or(std::make_pair(0u, false));
+          if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+            // Represent vector_reduce_add(ZExt(<n x i1>)) to
+            // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+            auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+            IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+            VectorCost =
+                TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                      getWidenedType(ScalarTy, ReduxWidth),
+                                      TTI::CastContextHint::None, CostKind) +
+                TTI->getIntrinsicInstrCost(ICA, CostKind);
+          } else {
+            VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                         FMF, CostKind);
+          }
         }
       }
       ScalarCost = EvaluateScalarCost([&]() {
@@ -20181,11 +20223,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
           "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }

llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,9 @@ define i16 @test(i16 %call37) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <8 x i1> [[TMP4]] to i8
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP8]])
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP7]] to i16
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i16 [[TMP6]], 0
 ; CHECK-NEXT:    ret i16 [[OP_RDX]]
 ;

llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll

Lines changed: 3 additions & 2 deletions
@@ -14,8 +14,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <4 x i1> [[TMP5]] to i4
+; CHECK-NEXT:    [[TMP11:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP10]])
+; CHECK-NEXT:    [[TMP7:%.*]] = zext i4 [[TMP11]] to i32
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]
 ; CHECK-NEXT:    ret i32 [[OP_RDX]]
 ;
