[SLP]Fix perfect diamond match with extractelements in scalars #132466

Merged
66 changes: 52 additions & 14 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5310,12 +5310,11 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind,
/// This is similar to TargetTransformInfo::getScalarizationOverhead, but if
/// ScalarTy is a FixedVectorType, a vector will be inserted or extracted
/// instead of a scalar.
static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
Type *ScalarTy, VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
TTI::TargetCostKind CostKind,
ArrayRef<Value *> VL = {}) {
static InstructionCost
getScalarizationOverhead(const TargetTransformInfo &TTI, Type *ScalarTy,
VectorType *Ty, const APInt &DemandedElts, bool Insert,
bool Extract, TTI::TargetCostKind CostKind,
bool ForPoisonSrc = true, ArrayRef<Value *> VL = {}) {
assert(!isa<ScalableVectorType>(Ty) &&
"ScalableVectorType is not supported.");
assert(getNumElements(ScalarTy) * DemandedElts.getBitWidth() ==
@@ -5339,8 +5338,26 @@ static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
}
return Cost;
}
return TTI.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
CostKind, VL);
APInt NewDemandedElts = DemandedElts;
InstructionCost Cost = 0;
if (!ForPoisonSrc && Insert) {
// Handle insert into non-poison vector.
Collaborator:

I'm not certain this is entirely correct. As far as I can tell, you are trying to cost inserting a number of elements into an existing (non-poison) vector, while TTI::getScalarizationOverhead assumes Insert is for an ISD::BUILD_VECTOR style pattern. Would an SK_Select shuffle not be closer to what you're after? The isZero checks below suggest you've had to write this mainly for the single-element diamond case, so maybe that needs to be handled by a DemandedElts.isOneBitSet() special case instead?
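For reference, SK_Select models a lane-preserving blend of two vectors. A minimal hypothetical IR sketch of such a shuffle (function name and types are illustrative, not from the patch):

```llvm
; SK_Select-style shuffle: every result lane keeps its lane index and merely
; chooses between the two sources, modeling "insert some elements into an
; existing vector" as a single blend.
define <4 x float> @select_blend(<4 x float> %src, <4 x float> %new) {
  %blend = shufflevector <4 x float> %src, <4 x float> %new, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
  ret <4 x float> %blend
}
```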

Member Author:

Let's use a conservative solution for now. Currently, codegen just generates an insertelement for each unique scalar here, so I replicated that for the non-poison input value. We need to teach getScalarizationOverhead about a non-poison input vector. Also, I'm not sure the current generic implementation is fully correct: it assumes that all insertelement instructions are inserted into a poison vector, but that is true only for the first instruction.

Collaborator:

OK - add a TODO comment for now please.

// TODO: Need to teach getScalarizationOverhead about insert elements into
// non-poison input vector to better handle such cases. Currently, it is
// very conservative and may "pessimize" the vectorization.
for (unsigned I : seq(DemandedElts.getBitWidth())) {
if (!DemandedElts[I])
continue;
Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, CostKind,
I, Constant::getNullValue(Ty),
VL.empty() ? nullptr : VL[I]);
}
NewDemandedElts.clearAllBits();
} else if (!NewDemandedElts.isZero()) {
Cost += TTI.getScalarizationOverhead(Ty, NewDemandedElts, Insert, Extract,
CostKind, VL);
}
return Cost;
}

/// Correctly creates insert_subvector, checking that the index is multiple of
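The cost distinction discussed in the review thread above can be made concrete with a minimal LLVM IR sketch (hypothetical functions, not from the patch): TTI::getScalarizationOverhead models the first pattern, while the new conservative path sums a per-element getVectorInstrCost for the second.

```llvm
; BUILD_VECTOR-style chain: inserts start from poison, so the whole sequence
; is one vector materialization and can be costed as such.
define <4 x float> @from_poison(float %a, float %b) {
  %v0 = insertelement <4 x float> poison, float %a, i32 0
  %v1 = insertelement <4 x float> %v0, float %b, i32 1
  ret <4 x float> %v1
}

; Non-poison source: the remaining lanes of %src must be preserved, so each
; insertelement is a real lane update and is costed individually.
define <4 x float> @into_existing(<4 x float> %src, float %a, float %b) {
  %w0 = insertelement <4 x float> %src, float %a, i32 0
  %w1 = insertelement <4 x float> %w0, float %b, i32 1
  ret <4 x float> %w1
}
```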
@@ -11684,6 +11701,15 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
// No need to delay the cost estimation during analysis.
return std::nullopt;
}
/// Reset the builder to handle perfect diamond match.
void resetForSameNode() {
IsFinalized = false;
CommonMask.clear();
InVectors.clear();
Cost = 0;
VectorizedVals.clear();
SameNodesEstimated = true;
}
void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
if (&E1 == &E2) {
assert(all_of(Mask,
Expand Down Expand Up @@ -14890,15 +14916,18 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
ShuffledElements.setBit(I);
ShuffleMask[I] = Res.first->second;
}
if (!DemandedElements.isZero())
Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
/*Insert=*/true,
/*Extract=*/false, CostKind, VL);
if (ForPoisonSrc)
if (ForPoisonSrc) {
Cost = getScalarizationOverhead(*TTI, ScalarTy, VecTy,
/*DemandedElts*/ ~ShuffledElements,
/*Insert*/ true,
/*Extract*/ false, CostKind, VL);
/*Extract*/ false, CostKind,
/*ForPoisonSrc=*/true, VL);
} else if (!DemandedElements.isZero()) {
Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
/*Insert=*/true,
/*Extract=*/false, CostKind,
/*ForPoisonSrc=*/false, VL);
}
if (DuplicateNonConst)
Cost += ::getShuffleCost(*TTI, TargetTransformInfo::SK_PermuteSingleSrc,
VecTy, ShuffleMask);
@@ -15556,6 +15585,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
PoisonValue::get(PointerType::getUnqual(ScalarTy->getContext())),
MaybeAlign());
}
/// Reset the builder to handle perfect diamond match.
void resetForSameNode() {
IsFinalized = false;
CommonMask.clear();
InVectors.clear();
}
/// Adds 2 input vectors (in form of tree entries) and the mask for their
/// shuffling.
void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -16111,6 +16146,9 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
Mask[I] = FrontTE->findLaneForValue(V);
}
}
// Reset the builder(s) to correctly handle perfect diamond matched
// nodes.
ShuffleBuilder.resetForSameNode();
ShuffleBuilder.add(*FrontTE, Mask);
// Full matched entry found, no need to insert subvectors.
Res = ShuffleBuilder.finalize(E->getCommonMask(), {}, {});
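For context, the "perfect diamond match" handled above is a buildvector node whose scalars all map onto the lanes of a single existing tree entry, so the builder can reuse that entry with just a mask; the resetForSameNode calls discard state the builder may have accumulated while processing extractelement scalars. A hypothetical IR shape that produces such a node (illustrative only):

```llvm
; The gathered scalars are extractelements covering exactly the lanes of %v,
; so the node is a permutation of an already-vectorized value: SLP can emit
; a single shufflevector instead of extract/insert chains.
define <2 x double> @diamond(<2 x double> %v) {
  %e0 = extractelement <2 x double> %v, i32 0
  %e1 = extractelement <2 x double> %v, i32 1
  %b0 = insertelement <2 x double> poison, double %e1, i32 0
  %b1 = insertelement <2 x double> %b0, double %e0, i32 1
  ret <2 x double> %b1
}
```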
21 changes: 9 additions & 12 deletions llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
@@ -10,18 +10,15 @@ define <4 x double> @test(ptr %ia, ptr %ib, ptr %ic, ptr %id, ptr %ie, ptr %x) {
; CHECK-NEXT: [[I4275:%.*]] = load double, ptr [[ID]], align 8
; CHECK-NEXT: [[I4277:%.*]] = load double, ptr [[IE]], align 8
; CHECK-NEXT: [[I4326:%.*]] = load <4 x double>, ptr [[X]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <2 x double> poison, double [[I4238]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[I4252]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x double> poison, double [[I4264]], i32 0
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x double> [[TMP6]], double [[I4277]], i32 1
; CHECK-NEXT: [[TMP8:%.*]] = fmul fast <2 x double> [[TMP5]], [[TMP7]]
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[I44281:%.*]] = shufflevector <4 x double> [[TMP9]], <4 x double> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
; CHECK-NEXT: ret <4 x double> [[I44281]]
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> <i32 0, i32 poison>
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x double> poison, double [[I4238]], i32 0
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP4]], double [[I4252]], i32 1
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x double> [[TMP5]], double [[I4264]], i32 2
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP6]], double [[I4277]], i32 3
; CHECK-NEXT: [[TMP8:%.*]] = fmul fast <4 x double> [[TMP3]], [[TMP7]]
; CHECK-NEXT: ret <4 x double> [[TMP8]]
;
%i4238 = load double, ptr %ia, align 8
%i4252 = load double, ptr %ib, align 8
44 changes: 8 additions & 36 deletions llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
@@ -49,24 +49,10 @@ define i32 @reduce_and4(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i32> %v3, <
;
; AVX512-LABEL: @reduce_and4(
; AVX512-NEXT: entry:
; AVX512-NEXT: [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
; AVX512-NEXT: [[VECEXT1:%.*]] = extractelement <4 x i32> [[V1]], i64 1
; AVX512-NEXT: [[VECEXT2:%.*]] = extractelement <4 x i32> [[V1]], i64 2
; AVX512-NEXT: [[VECEXT4:%.*]] = extractelement <4 x i32> [[V1]], i64 3
; AVX512-NEXT: [[VECEXT7:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
; AVX512-NEXT: [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
; AVX512-NEXT: [[VECEXT10:%.*]] = extractelement <4 x i32> [[V2]], i64 2
; AVX512-NEXT: [[VECEXT12:%.*]] = extractelement <4 x i32> [[V2]], i64 3
; AVX512-NEXT: [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
; AVX512-NEXT: [[TMP1:%.*]] = insertelement <16 x i32> [[TMP0]], i32 [[VECEXT8]], i32 8
; AVX512-NEXT: [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT7]], i32 9
; AVX512-NEXT: [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT10]], i32 10
; AVX512-NEXT: [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT12]], i32 11
; AVX512-NEXT: [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 12
; AVX512-NEXT: [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT]], i32 13
; AVX512-NEXT: [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT2]], i32 14
; AVX512-NEXT: [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT4]], i32 15
; AVX512-NEXT: [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP8]])
; AVX512-NEXT: [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
; AVX512-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
; AVX512-NEXT: [[RDX_OP:%.*]] = and <8 x i32> [[TMP0]], [[TMP1]]
; AVX512-NEXT: [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
; AVX512-NEXT: [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
; AVX512-NEXT: ret i32 [[OP_RDX1]]
;
@@ -144,24 +130,10 @@ define i32 @reduce_and4_transpose(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i
; AVX2-NEXT: ret i32 [[OP_RDX]]
;
; AVX512-LABEL: @reduce_and4_transpose(
; AVX512-NEXT: [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
; AVX512-NEXT: [[VECEXT1:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
; AVX512-NEXT: [[VECEXT7:%.*]] = extractelement <4 x i32> [[V1]], i64 1
; AVX512-NEXT: [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
; AVX512-NEXT: [[VECEXT15:%.*]] = extractelement <4 x i32> [[V1]], i64 2
; AVX512-NEXT: [[VECEXT16:%.*]] = extractelement <4 x i32> [[V2]], i64 2
; AVX512-NEXT: [[VECEXT23:%.*]] = extractelement <4 x i32> [[V1]], i64 3
; AVX512-NEXT: [[VECEXT24:%.*]] = extractelement <4 x i32> [[V2]], i64 3
; AVX512-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
; AVX512-NEXT: [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT24]], i32 8
; AVX512-NEXT: [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT16]], i32 9
; AVX512-NEXT: [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT8]], i32 10
; AVX512-NEXT: [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 11
; AVX512-NEXT: [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT23]], i32 12
; AVX512-NEXT: [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT15]], i32 13
; AVX512-NEXT: [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT7]], i32 14
; AVX512-NEXT: [[TMP9:%.*]] = insertelement <16 x i32> [[TMP8]], i32 [[VECEXT]], i32 15
; AVX512-NEXT: [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP9]])
; AVX512-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
; AVX512-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
; AVX512-NEXT: [[RDX_OP:%.*]] = and <8 x i32> [[TMP1]], [[TMP2]]
; AVX512-NEXT: [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
; AVX512-NEXT: [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
; AVX512-NEXT: ret i32 [[OP_RDX1]]
;