Skip to content

Commit f00bc10

Browse files
committed
[LSV] Insert casts to vectorize mismatched types
After collecting equivalence classes, loop over each distinct pair of them and check if they could be merged into one. Consider classes A and B such that their leaders differ only by their scalar bitwidth. (We do not merge them otherwise.) Let N be the scalar bitwidth of the leader instruction in A. Iterate over all instructions in B and ensure their total bitwidths match the total bitwidth of the leader instruction of A. Finally, cast each instruction in B with a mismatched type to an intN type.
1 parent d3097b7 commit f00bc10

File tree

3 files changed

+87
-101
lines changed

3 files changed

+87
-101
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ class Vectorizer {
324324
Instruction *ChainElem, Instruction *ChainBegin,
325325
const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
326326

327+
/// Merge the equivalence classes if casts could be inserted in one to match
328+
/// the scalar bitwidth of the instructions in the other class.
329+
void insertCastsToMergeClasses(EquivalenceClassMap &EQClasses);
330+
327331
/// Merges the equivalence classes if they have underlying objects that differ
328332
/// by one level of indirection (i.e., one is a getelementptr and the other is
329333
/// the base pointer in that getelementptr).
@@ -1310,6 +1314,82 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects(
13101314
return std::nullopt;
13111315
}
13121316

1317+
void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1318+
if (EQClasses.size() < 2)
1319+
return;
1320+
1321+
// Loop over all equivalence classes and try to merge them. Keep track of
1322+
// classes that are merged into others.
1323+
DenseSet<EqClassKey> ClassesToErase;
1324+
for (auto EC1 : EQClasses) {
1325+
for (auto EC2 : EQClasses) {
1326+
if (ClassesToErase.contains(EC2.first) || EC1 <= EC2)
1327+
continue;
1328+
1329+
auto [Ptr1, AS1, TySize1, IsLoad1] = EC1.first;
1330+
auto [Ptr2, AS2, TySize2, IsLoad2] = EC2.first;
1331+
1332+
// Attempt to merge EC2 into EC1. Skip if the pointers, address spaces or
1333+
// whether the leader instruction is a load/store are different. Also skip
1334+
// if the scalar bitwidth of the first equivalence class is smaller than
1335+
// the second one to avoid reconsidering the same equivalence class pair.
1336+
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
1337+
continue;
1338+
1339+
// Ensure all instructions in EC2 can be bitcasted into NewTy.
1340+
/// TODO: NewTyBits is needed as stuctured binded variables cannot be
1341+
/// captured by a lambda until C++20.
1342+
auto NewTyBits = std::get<2>(EC1.first);
1343+
if (any_of(EC2.second, [&](Instruction *I) {
1344+
return DL.getTypeSizeInBits(getLoadStoreType(I)) != NewTyBits;
1345+
}))
1346+
continue;
1347+
1348+
// Create a new type for the equivalence class.
1349+
/// TODO: NewTy should be an FP type for an all-FP equivalence class.
1350+
auto *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits);
1351+
for (auto *Inst : EC2.second) {
1352+
auto *Ptr = getLoadStorePointerOperand(Inst);
1353+
auto *OrigTy = Inst->getType();
1354+
if (OrigTy == NewTy)
1355+
continue;
1356+
if (auto *LI = dyn_cast<LoadInst>(Inst)) {
1357+
Builder.SetInsertPoint(LI->getIterator());
1358+
auto *NewLoad = Builder.CreateLoad(NewTy, Ptr);
1359+
auto *Cast = Builder.CreateBitOrPointerCast(
1360+
NewLoad, OrigTy, NewLoad->getName() + ".cast");
1361+
LI->replaceAllUsesWith(Cast);
1362+
LI->eraseFromParent();
1363+
EQClasses[EC1.first].emplace_back(NewLoad);
1364+
} else {
1365+
auto *SI = cast<StoreInst>(Inst);
1366+
Builder.SetInsertPoint(SI->getIterator());
1367+
auto *Cast = Builder.CreateBitOrPointerCast(
1368+
SI->getValueOperand(), NewTy,
1369+
SI->getValueOperand()->getName() + ".cast");
1370+
auto *NewStore = Builder.CreateStore(
1371+
Cast, getLoadStorePointerOperand(SI), SI->isVolatile());
1372+
SI->eraseFromParent();
1373+
EQClasses[EC1.first].emplace_back(NewStore);
1374+
}
1375+
}
1376+
1377+
// Sort the instructions in the equivalence class by their order in the
1378+
// basic block. This is important to ensure that the instructions are
1379+
// vectorized in the correct order.
1380+
std::sort(EQClasses[EC1.first].begin(), EQClasses[EC1.first].end(),
1381+
[](Instruction *A, Instruction *B) {
1382+
return A && B && A->comesBefore(B);
1383+
});
1384+
ClassesToErase.insert(EC2.first);
1385+
}
1386+
}
1387+
1388+
// Erase the equivalence classes that were merged into others.
1389+
for (auto Key : ClassesToErase)
1390+
EQClasses.erase(Key);
1391+
}
1392+
13131393
void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
13141394
if (EQClasses.size() < 2) // There is nothing to merge.
13151395
return;
@@ -1495,7 +1575,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
14951575
/*IsLoad=*/LI != nullptr}]
14961576
.emplace_back(&I);
14971577
}
1498-
1578+
insertCastsToMergeClasses(Ret);
14991579
mergeEquivalenceClasses(Ret);
15001580
return Ret;
15011581
}

llvm/test/Transforms/LoadStoreVectorizer/AMDGPU/insert-casts-vectorize.ll

Lines changed: 0 additions & 89 deletions
This file was deleted.

llvm/test/Transforms/LoadStoreVectorizer/AMDGPU/merge-vectors.ll

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ entry:
9595
ret void
9696
}
9797

98-
; Ideally this would be merged
9998
; CHECK-LABEL: @merge_load_i32_v2i16(
100-
; CHECK: load i32,
101-
; CHECK: load <2 x i16>
99+
; CHECK: load <2 x i32>
100+
; CHECK: extractelement <2 x i32> %0, i32 0
101+
; CHECK: extractelement <2 x i32> %0, i32 1
102102
define amdgpu_kernel void @merge_load_i32_v2i16(ptr addrspace(1) nocapture %a) #0 {
103103
entry:
104104
%a.1 = getelementptr inbounds i32, ptr addrspace(1) %a, i32 1
@@ -113,14 +113,9 @@ attributes #0 = { nounwind }
113113
attributes #1 = { nounwind readnone }
114114

115115
; CHECK-LABEL: @merge_i32_2i16_float_4i8(
116-
; CHECK: load i32
117-
; CHECK: load <2 x i16>
118-
; CHECK: load float
119-
; CHECK: load <4 x i8>
120-
; CHECK: store i32
121-
; CHECK: store <2 x i16>
122-
; CHECK: store float
123-
; CHECK: store <4 x i8>
116+
; CHECK: load <4 x i32>
117+
; CHECK: store <2 x i32>
118+
; CHECK: store <2 x i32>
124119
define void @merge_i32_2i16_float_4i8(ptr addrspace(1) %ptr1, ptr addrspace(2) %ptr2) {
125120
%gep1 = getelementptr inbounds i32, ptr addrspace(1) %ptr1, i64 0
126121
%load1 = load i32, ptr addrspace(1) %gep1, align 4

0 commit comments

Comments
 (0)