Skip to content

[SLP] no need to generate extract for in-tree uses for original scalar instructions #76077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 33 additions & 54 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4925,36 +4925,34 @@ void BoUpSLP::buildExternalUses(
LLVM_DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n");

Instruction *UserInst = dyn_cast<Instruction>(U);
if (!UserInst)
if (!UserInst || isDeleted(UserInst))
continue;

if (isDeleted(UserInst))
// Ignore users in the user ignore list.
if (UserIgnoreList && UserIgnoreList->contains(UserInst))
continue;

// Skip in-tree scalars that become vectors
if (TreeEntry *UseEntry = getTreeEntry(U)) {
Value *UseScalar = UseEntry->Scalars[0];
// Some in-tree scalars will remain as scalar in vectorized
// instructions. If that is the case, the one in Lane 0 will
// instructions. If that is the case, the one in FoundLane will
// be used.
if (UseScalar != U ||
UseEntry->State == TreeEntry::ScatterVectorize ||
if (UseEntry->State == TreeEntry::ScatterVectorize ||
UseEntry->State == TreeEntry::PossibleStridedVectorize ||
!doesInTreeUserNeedToExtract(Scalar, UserInst, TLI)) {
!doesInTreeUserNeedToExtract(
Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) {
LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
<< ".\n");
assert(UseEntry->State != TreeEntry::NeedToGather && "Bad state");
continue;
}
U = nullptr;
}

// Ignore users in the user ignore list.
if (UserIgnoreList && UserIgnoreList->contains(UserInst))
continue;

LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane "
<< Lane << " from " << *Scalar << ".\n");
ExternalUses.push_back(ExternalUser(Scalar, U, FoundLane));
LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *UserInst
<< " from lane " << Lane << " from " << *Scalar
<< ".\n");
ExternalUses.emplace_back(Scalar, U, FoundLane);
}
}
}
Expand Down Expand Up @@ -11493,17 +11491,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
Value *PO = LI->getPointerOperand();
if (E->State == TreeEntry::Vectorize) {
NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());

// The pointer operand uses an in-tree scalar so we add the new
// LoadInst to ExternalUses list to make sure that an extract will
// be generated in the future.
if (isa<Instruction>(PO)) {
if (TreeEntry *Entry = getTreeEntry(PO)) {
// Find which lane we need to extract.
unsigned FoundLane = Entry->findLaneForValue(PO);
ExternalUses.emplace_back(PO, NewLI, FoundLane);
}
}
} else {
assert((E->State == TreeEntry::ScatterVectorize ||
E->State == TreeEntry::PossibleStridedVectorize) &&
Expand Down Expand Up @@ -11539,17 +11526,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
StoreInst *ST =
Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());

// The pointer operand uses an in-tree scalar, so add the new StoreInst to
// ExternalUses to make sure that an extract will be generated in the
// future.
if (isa<Instruction>(Ptr)) {
if (TreeEntry *Entry = getTreeEntry(Ptr)) {
// Find which lane we need to extract.
unsigned FoundLane = Entry->findLaneForValue(Ptr);
ExternalUses.push_back(ExternalUser(Ptr, ST, FoundLane));
}
}

Value *V = propagateMetadata(ST, E->Scalars);

E->VectorizedValue = V;
Expand Down Expand Up @@ -11654,18 +11630,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
CI->getOperandBundlesAsDefs(OpBundles);
Value *V = Builder.CreateCall(CF, OpVecs, OpBundles);

// The scalar argument uses an in-tree scalar so we add the new vectorized
// call to ExternalUses list to make sure that an extract will be
// generated in the future.
if (isa_and_present<Instruction>(ScalarArg)) {
if (TreeEntry *Entry = getTreeEntry(ScalarArg)) {
// Find which lane we need to extract.
unsigned FoundLane = Entry->findLaneForValue(ScalarArg);
ExternalUses.push_back(
ExternalUser(ScalarArg, cast<User>(V), FoundLane));
}
}

propagateIRFlags(V, E->Scalars, VL0);
V = FinalShuffle(V, E, VecTy, IsSigned);

Expand Down Expand Up @@ -11877,6 +11841,7 @@ Value *BoUpSLP::vectorizeTree(
DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
SmallDenseSet<Value *, 4> UsedInserts;
DenseMap<Value *, Value *> VectorCasts;
SmallDenseSet<Value *, 4> ScalarsWithNullptrUser;
// Extract all of the elements with the external uses.
for (const auto &ExternalUse : ExternalUses) {
Value *Scalar = ExternalUse.Scalar;
Expand Down Expand Up @@ -11947,13 +11912,27 @@ Value *BoUpSLP::vectorizeTree(
VectorToInsertElement.try_emplace(Vec, IE);
return Vec;
};
// If User == nullptr, the Scalar is used as extra arg. Generate
// ExtractElement instruction and update the record for this scalar in
// ExternallyUsedValues.
// If User == nullptr, the Scalar remains as scalar in vectorized
// instructions or is used as extra arg. Generate ExtractElement instruction
// and update the record for this scalar in ExternallyUsedValues.
if (!User) {
assert(ExternallyUsedValues.count(Scalar) &&
"Scalar with nullptr as an external user must be registered in "
"ExternallyUsedValues map");
if (!ScalarsWithNullptrUser.insert(Scalar).second)
continue;
assert((ExternallyUsedValues.count(Scalar) ||
any_of(Scalar->users(),
[&](llvm::User *U) {
TreeEntry *UseEntry = getTreeEntry(U);
return UseEntry &&
[Review comment — Member] Suggested change:
-                         return UseEntry &&
+                         return UseEntry && UseEntry->State != TreeEntry::Vectorize &&

[Reply — Contributor Author] I suppose this should be UseEntry->State == TreeEntry::Vectorize ?

[Reply — Member] Yes, looks so; also should be E->State != TreeEntry::Vectorize.

[Reply — Contributor Author] Updated to: UseEntry && UseEntry->State == TreeEntry::Vectorize && E->State == TreeEntry::Vectorize && doesInTreeUserNeedToExtract()


UseEntry->State == TreeEntry::Vectorize &&
E->State == TreeEntry::Vectorize &&
doesInTreeUserNeedToExtract(
Scalar,
cast<Instruction>(UseEntry->Scalars.front()),
TLI);
})) &&
"Scalar with nullptr User must be registered in "
"ExternallyUsedValues map or remain as scalar in vectorized "
"instructions");
if (auto *VecI = dyn_cast<Instruction>(Vec)) {
if (auto *PHI = dyn_cast<PHINode>(VecI))
Builder.SetInsertPoint(PHI->getParent(),
Expand Down
34 changes: 17 additions & 17 deletions llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ define i32 @fn1() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load ptr, ptr @a, align 8
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[TMP0]], i32 0
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i64, <2 x ptr> [[SHUFFLE]], <2 x i64> <i64 11, i64 56>
; CHECK-NEXT: [[TMP3:%.*]] = ptrtoint <2 x ptr> [[TMP2]] to <2 x i64>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP2]], i32 0
; CHECK-NEXT: store <2 x i64> [[TMP3]], ptr [[TMP4]], align 8
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i64, <2 x ptr> [[TMP2]], <2 x i64> <i64 11, i64 56>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP3]], i32 0
; CHECK-NEXT: [[TMP5:%.*]] = ptrtoint <2 x ptr> [[TMP3]] to <2 x i64>
; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[TMP4]], align 8
; CHECK-NEXT: ret i32 undef
;
entry:
Expand All @@ -34,13 +34,13 @@ declare float @llvm.powi.f32.i32(float, i32)
define void @fn2(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: @fn2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[A:%.*]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr [[B:%.*]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = add <4 x i32> [[TMP1]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = sitofp <4 x i32> [[TMP4]] to <4 x float>
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[TMP4]], i32 0
; CHECK-NEXT: [[TMP7:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[TMP5]], i32 [[TMP6]])
; CHECK-NEXT: store <4 x float> [[TMP7]], ptr [[C:%.*]], align 4
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[A:%.*]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[B:%.*]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = add <4 x i32> [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = sitofp <4 x i32> [[TMP2]] to <4 x float>
; CHECK-NEXT: [[TMP5:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[TMP4]], i32 [[TMP3]])
; CHECK-NEXT: store <4 x float> [[TMP5]], ptr [[C:%.*]], align 4
; CHECK-NEXT: ret void
;
entry:
Expand Down Expand Up @@ -90,12 +90,12 @@ define void @externally_used_ptrs() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load ptr, ptr @a, align 8
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[TMP0]], i32 0
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i64, <2 x ptr> [[SHUFFLE]], <2 x i64> <i64 56, i64 11>
; CHECK-NEXT: [[TMP3:%.*]] = ptrtoint <2 x ptr> [[TMP2]] to <2 x i64>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP2]], i32 1
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i64, <2 x ptr> [[TMP2]], <2 x i64> <i64 56, i64 11>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP3]], i32 1
; CHECK-NEXT: [[TMP5:%.*]] = ptrtoint <2 x ptr> [[TMP3]] to <2 x i64>
; CHECK-NEXT: [[TMP6:%.*]] = load <2 x i64>, ptr [[TMP4]], align 8
; CHECK-NEXT: [[TMP7:%.*]] = add <2 x i64> [[TMP3]], [[TMP6]]
; CHECK-NEXT: [[TMP7:%.*]] = add <2 x i64> [[TMP5]], [[TMP6]]
; CHECK-NEXT: store <2 x i64> [[TMP7]], ptr [[TMP4]], align 8
; CHECK-NEXT: ret void
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ define void @"foo"(ptr addrspace(1) %0, ptr addrspace(1) %1) #0 {
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr addrspace(1)> poison, ptr addrspace(1) [[TMP0:%.*]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x ptr addrspace(1)> [[TMP3]], <4 x ptr addrspace(1)> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, <4 x ptr addrspace(1)> [[TMP4]], <4 x i64> <i64 8, i64 12, i64 28, i64 24>
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i8, ptr addrspace(1) [[TMP1:%.*]], i64 8
; CHECK-NEXT: [[TMP7:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> [[TMP5]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> poison)
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x float> [[TMP7]], <4 x float> poison, <8 x i32> <i32 0, i32 3, i32 0, i32 3, i32 2, i32 1, i32 2, i32 1>
; CHECK-NEXT: [[TMP9:%.*]] = load <8 x float>, ptr addrspace(1) [[TMP6]], align 4
; CHECK-NEXT: [[TMP10:%.*]] = fmul <8 x float> [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP11:%.*]] = fadd <8 x float> [[TMP10]], zeroinitializer
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <8 x float> [[TMP11]], <8 x float> poison, <8 x i32> <i32 0, i32 5, i32 2, i32 7, i32 4, i32 1, i32 6, i32 3>
; CHECK-NEXT: [[TMP13:%.*]] = extractelement <4 x ptr addrspace(1)> [[TMP5]], i32 0
; CHECK-NEXT: store <8 x float> [[TMP12]], ptr addrspace(1) [[TMP13]], align 4
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x ptr addrspace(1)> [[TMP5]], i32 0
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr addrspace(1) [[TMP1:%.*]], i64 8
; CHECK-NEXT: [[TMP8:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> [[TMP5]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> poison)
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x float> [[TMP8]], <4 x float> poison, <8 x i32> <i32 0, i32 3, i32 0, i32 3, i32 2, i32 1, i32 2, i32 1>
; CHECK-NEXT: [[TMP10:%.*]] = load <8 x float>, ptr addrspace(1) [[TMP7]], align 4
; CHECK-NEXT: [[TMP11:%.*]] = fmul <8 x float> [[TMP9]], [[TMP10]]
; CHECK-NEXT: [[TMP12:%.*]] = fadd <8 x float> [[TMP11]], zeroinitializer
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <8 x float> [[TMP12]], <8 x float> poison, <8 x i32> <i32 0, i32 5, i32 2, i32 7, i32 4, i32 1, i32 6, i32 3>
; CHECK-NEXT: store <8 x float> [[TMP13]], ptr addrspace(1) [[TMP6]], align 4
; CHECK-NEXT: ret void
;
%3 = getelementptr inbounds i8, ptr addrspace(1) %0, i64 8
Expand Down