[SLP]Model reduction_add(ext(<n x i1>)) as ext(ctpop(bitcast <n x i1> to int n)) #116875

Merged
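The patch teaches SLP to lower an integer add reduction whose operands are all zero-extended i1 values through a population count: counting the set bits of the mask, viewed as an n-bit integer, gives the same value as summing the zero-extended lanes. A minimal LLVM IR sketch of the rewrite follows; the 8-lane mask, the i16 result type, and the value names are illustrative (they mirror the first test update below), not part of the patch itself.

Before:

  %ext = zext <8 x i1> %mask to <8 x i16>
  %sum = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %ext)

After:

  %bits = bitcast <8 x i1> %mask to i8
  %pop  = call i8 @llvm.ctpop.i8(i8 %bits)
  %sum  = zext i8 %pop to i16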
107 changes: 80 additions & 27 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,6 +1371,18 @@ class BoUpSLP {
return MinBWs.at(VectorizableTree.front().get()).second;
}

/// Returns reduction bitwidth and signedness, if it does not match the
/// original requested size.
std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
if (ReductionBitWidth == 0 ||
ReductionBitWidth ==
DL->getTypeSizeInBits(
VectorizableTree.front()->Scalars.front()->getType()))
return std::nullopt;
return std::make_pair(ReductionBitWidth,
MinBWs.at(VectorizableTree.front().get()).second);
}

/// Builds external uses of the vectorized scalars, i.e. the list of
/// vectorized scalars to be extracted, their lanes and their scalar users. \p
/// ExternallyUsedValues contains additional list of external uses to handle
@@ -17887,24 +17899,37 @@ void BoUpSLP::computeMinimumValueSizes() {
// Add reduction ops sizes, if any.
if (UserIgnoreList &&
isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
for (Value *V : *UserIgnoreList) {
auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
unsigned BitWidth1 = NumTypeBits - NumSignBits;
if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
++BitWidth1;
unsigned BitWidth2 = BitWidth1;
if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
auto Mask = DB->getDemandedBits(cast<Instruction>(V));
BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
// Convert vector_reduce_add(ZExt(<n x i1>)) to
// ZExtOrTrunc(ctpop(bitcast <n x i1> to int n)).
if (all_of(*UserIgnoreList,
[](Value *V) {
return cast<Instruction>(V)->getOpcode() == Instruction::Add;
}) &&
VectorizableTree.front()->State == TreeEntry::Vectorize &&
VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
Builder.getInt1Ty()) {
ReductionBitWidth = 1;
} else {
for (Value *V : *UserIgnoreList) {
unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
TypeSize NumTypeBits = DL->getTypeSizeInBits(V->getType());
unsigned BitWidth1 = NumTypeBits - NumSignBits;
if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
++BitWidth1;
unsigned BitWidth2 = BitWidth1;
if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
APInt Mask = DB->getDemandedBits(cast<Instruction>(V));
BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
}
ReductionBitWidth =
std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
}
if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
ReductionBitWidth = 8;

ReductionBitWidth = bit_ceil(ReductionBitWidth);
}
}
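To make the fallback path concrete with hypothetical numbers (not taken from the patch or its tests): for an i32 reduced value with 29 known sign bits, BitWidth1 = 32 - 29 = 3, bumped to 4 if the value may be negative; assuming the demanded-bits mask also gives BitWidth2 = 4, the candidate width is min(4, 4) = 4, which the clamp raises to 8 and bit_ceil leaves at 8. The new branch above skips this computation for add reductions rooted at a zext from i1 and pins ReductionBitWidth to 1, which is what later selects the ctpop lowering.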
bool IsTopRoot = NodeIdx == 0;
while (NodeIdx < VectorizableTree.size() &&
@@ -19760,8 +19785,8 @@ class HorizontalReduction {

// Estimate cost.
InstructionCost TreeCost = V.getTreeCost(VL);
InstructionCost ReductionCost =
getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
InstructionCost ReductionCost = getReductionCost(
TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
InstructionCost Cost = TreeCost + ReductionCost;
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
<< " for reduction\n");
@@ -19866,10 +19891,12 @@ class HorizontalReduction {
createStrideMask(I, ScalarTyNumElements, VL.size());
Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
ReducedSubTree = Builder.CreateInsertElement(
ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
ReducedSubTree,
emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
}
} else {
ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
RdxRootInst->getType());
}
if (ReducedSubTree->getType() != VL.front()->getType()) {
assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20050,12 +20077,13 @@ class HorizontalReduction {

private:
/// Calculate the cost of a reduction.
InstructionCost getReductionCost(TargetTransformInfo *TTI,
ArrayRef<Value *> ReducedVals,
bool IsCmpSelMinMax, unsigned ReduxWidth,
FastMathFlags FMF) {
InstructionCost getReductionCost(
TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
bool IsCmpSelMinMax, FastMathFlags FMF,
const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *ScalarTy = ReducedVals.front()->getType();
unsigned ReduxWidth = ReducedVals.size();
FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
InstructionCost VectorCost = 0, ScalarCost;
// If all of the reduced values are constant, the vector cost is 0, since
@@ -20114,8 +20142,22 @@ class HorizontalReduction {
VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
/*Extract*/ false, TTI::TCK_RecipThroughput);
} else {
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
CostKind);
auto [Bitwidth, IsSigned] =
BitwidthAndSign.value_or(std::make_pair(0u, false));
if (RdxKind == RecurKind::Add && Bitwidth == 1) {
// Represent vector_reduce_add(ZExt(<n x i1>)) as
// ZExtOrTrunc(ctpop(bitcast <n x i1> to int n)).
auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
VectorCost =
TTI->getCastInstrCost(Instruction::BitCast, IntTy,
getWidenedType(ScalarTy, ReduxWidth),
TTI::CastContextHint::None, CostKind) +
TTI->getIntrinsicInstrCost(ICA, CostKind);
} else {
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
FMF, CostKind);
}
}
}
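In cost terms, the new branch models the i1 add reduction as VectorCost = Cost(bitcast <ReduxWidth x i1> to iReduxWidth) + Cost(llvm.ctpop on iReduxWidth), with the actual numbers left to the target, instead of the generic getArithmeticReductionCost query used for every other reduction; for the first test below this amounts to the cost of a bitcast of <8 x i1> to i8 plus a ctpop.i8.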
ScalarCost = EvaluateScalarCost([&]() {
@@ -20152,11 +20194,22 @@ class HorizontalReduction {

/// Emit a horizontal reduction of the vectorized value.
Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
const TargetTransformInfo *TTI) {
const TargetTransformInfo *TTI, Type *DestTy) {
assert(VectorizedValue && "Need to have a vectorized tree node");
assert(RdxKind != RecurKind::FMulAdd &&
"A call to the llvm.fmuladd intrinsic is not handled yet");

auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
if (FTy->getScalarType() == Builder.getInt1Ty() &&
RdxKind == RecurKind::Add &&
DestTy->getScalarType() != FTy->getScalarType()) {
// Convert vector_reduce_add(ZExt(<n x i1>)) to
// ZExtOrTrunc(ctpop(bitcast <n x i1> to int n)).
Value *V = Builder.CreateBitCast(
VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
++NumVectorInstructions;
return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
}
++NumVectorInstructions;
return createSimpleReduction(Builder, VectorizedValue, RdxKind);
}
@@ -11,8 +11,9 @@ define i16 @test(i16 %call37) {
; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16>
; CHECK-NEXT: [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]])
; CHECK-NEXT: [[TMP8:%.*]] = bitcast <8 x i1> [[TMP4]] to i8
; CHECK-NEXT: [[TMP7:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP8]])
; CHECK-NEXT: [[TMP6:%.*]] = zext i8 [[TMP7]] to i16
; CHECK-NEXT: [[OP_RDX:%.*]] = add i16 [[TMP6]], 0
; CHECK-NEXT: ret i16 [[OP_RDX]]
;
@@ -14,8 +14,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
; CHECK-NEXT: [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
; CHECK-NEXT: [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
; CHECK-NEXT: [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
; CHECK-NEXT: [[TMP10:%.*]] = bitcast <4 x i1> [[TMP5]] to i4
; CHECK-NEXT: [[TMP11:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP10]])
; CHECK-NEXT: [[TMP7:%.*]] = zext i4 [[TMP11]] to i32
; CHECK-NEXT: [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]
; CHECK-NEXT: ret i32 [[OP_RDX]]
;