Skip to content

[SLP][REVEC] Make MinBWs support vector instructions. #103049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9527,8 +9527,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
// that the costs will be accurate.
auto It = MinBWs.find(E);
Type *OrigScalarTy = ScalarTy;
if (It != MinBWs.end())
if (It != MinBWs.end()) {
auto VecTy = dyn_cast<FixedVectorType>(ScalarTy);
ScalarTy = IntegerType::get(F->getContext(), It->second.first);
if (VecTy)
ScalarTy = getWidenedType(ScalarTy, VecTy->getNumElements());
}
auto *VecTy = getWidenedType(ScalarTy, VL.size());
unsigned EntryVF = E->getVectorFactor();
auto *FinalVecTy = getWidenedType(ScalarTy, EntryVF);
Expand Down Expand Up @@ -13127,8 +13131,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
else if (auto *IE = dyn_cast<InsertElementInst>(V))
ScalarTy = IE->getOperand(1)->getType();
auto It = MinBWs.find(E);
if (It != MinBWs.end())
if (It != MinBWs.end()) {
auto VecTy = dyn_cast<FixedVectorType>(ScalarTy);
ScalarTy = IntegerType::get(F->getContext(), It->second.first);
if (VecTy)
ScalarTy = getWidenedType(ScalarTy, VecTy->getNumElements());
}
auto *VecTy = getWidenedType(ScalarTy, E->Scalars.size());
if (E->isGather()) {
// Set insert point for non-reduction initial nodes.
Expand Down Expand Up @@ -16003,16 +16011,18 @@ void BoUpSLP::computeMinimumValueSizes() {
}

unsigned VF = E.getVectorFactor();
auto *TreeRootIT =
dyn_cast<IntegerType>(E.Scalars.front()->getType()->getScalarType());
Type *ScalarTy = E.Scalars.front()->getType();
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
auto *TreeRootIT = dyn_cast<IntegerType>(ScalarTy->getScalarType());
if (!TreeRootIT || !Opcode)
return 0u;

if (any_of(E.Scalars,
[&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
return 0u;

unsigned NumParts = TTI->getNumberOfParts(getWidenedType(TreeRootIT, VF));
unsigned NumParts = TTI->getNumberOfParts(
getWidenedType(TreeRootIT, VF * ScalarTyNumElements));

// The maximum bit width required to represent all the values that can be
// demoted without loss of precision. It would be safe to truncate the roots
Expand All @@ -16034,7 +16044,8 @@ void BoUpSLP::computeMinimumValueSizes() {
// we can truncate the roots to this narrower type.
for (Value *Root : E.Scalars) {
unsigned NumSignBits = ComputeNumSignBits(Root, *DL, 0, AC, nullptr, DT);
TypeSize NumTypeBits = DL->getTypeSizeInBits(Root->getType());
TypeSize NumTypeBits =
DL->getTypeSizeInBits(Root->getType()->getScalarType());
unsigned BitWidth1 = NumTypeBits - NumSignBits;
// If we can't prove that the sign bit is zero, we must add one to the
// maximum bit width to account for the unknown sign bit. This preserves
Expand Down Expand Up @@ -16206,7 +16217,8 @@ void BoUpSLP::computeMinimumValueSizes() {
// type, we can proceed with the narrowing. Otherwise, do nothing.
if (MaxBitWidth == 0 ||
MaxBitWidth >=
cast<IntegerType>(TreeRoot.front()->getType())->getBitWidth()) {
cast<IntegerType>(TreeRoot.front()->getType()->getScalarType())
->getBitWidth()) {
if (UserIgnoreList)
AnalyzedMinBWVals.insert(TreeRoot.begin(), TreeRoot.end());
continue;
Expand Down
24 changes: 24 additions & 0 deletions llvm/test/Transforms/SLPVectorizer/revec.ll
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,27 @@ for.body:
%5 = phi <2 x float> [ %5, %for.body ], [ zeroinitializer, %entry ]
br i1 false, label %for0, label %for.body
}

; Exercises bit-width minimization when the values being grouped are themselves
; fixed vectors: two <4 x i16> -> <4 x i32> sign extensions feed two adjacent
; <4 x i32> stores. The expected output below merges them into one <8 x i32>
; store whose operand is built from the two <4 x i16> inputs, truncated to
; <8 x i1>, and then zero-extended to <8 x i32>.
define void @test9() {
; CHECK-LABEL: @test9(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call <8 x i16> @llvm.vector.insert.v8i16.v4i16(<8 x i16> poison, <4 x i16> zeroinitializer, i64 0)
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.vector.insert.v8i16.v4i16(<8 x i16> [[TMP0]], <4 x i16> zeroinitializer, i64 4)
; CHECK-NEXT: br label [[FOR_BODY13:%.*]]
; CHECK: for.body13:
; CHECK-NEXT: [[TMP2:%.*]] = trunc <8 x i16> [[TMP1]] to <8 x i1>
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i1> [[TMP2]] to <8 x i32>
; CHECK-NEXT: store <8 x i32> [[TMP3]], ptr null, align 4
; CHECK-NEXT: br label [[FOR_BODY13]]
;
entry:
br label %for.body13

for.body13: ; preds = %for.body13, %entry
; Two identical widening casts of an all-zero <4 x i16> constant.
%vmovl.i111 = sext <4 x i16> zeroinitializer to <4 x i32>
%vmovl.i110 = sext <4 x i16> zeroinitializer to <4 x i32>
store <4 x i32> %vmovl.i111, ptr null, align 4
; Byte offset 16 is exactly the size of one <4 x i32>, so the second store
; lands immediately after the first.
%add.ptr29 = getelementptr i8, ptr null, i64 16
store <4 x i32> %vmovl.i110, ptr %add.ptr29, align 4
; The block unconditionally branches back to itself.
br label %for.body13
}
Loading