Skip to content

Commit 3024610

Browse files
committed
[LSV] Insert casts to vectorize mismatched types
After collecting equivalence classes, loop over each distinct pair of them and check if they could be merged into one. Consider classes A and B such that their leaders differ only by their scalar bitwidths. We do not yet merge them otherwise. Let N be the scalar bitwidth of the leader instruction in A. Iterate over all instructions in B and ensure their total bitwidths match the total bitwidth of the leader instruction of A. Finally, cast each instruction in B with a mismatched type to a pointer, integer or floating point type. Resolve issue #97715 Change-Id: Ib64fd98de5c908262947648ad14dc53b61814642
1 parent d3097b7 commit 3024610

36 files changed

+2330
-2102
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
6161
#include "llvm/ADT/APInt.h"
6262
#include "llvm/ADT/ArrayRef.h"
63+
#include "llvm/ADT/Bitset.h"
6364
#include "llvm/ADT/DenseMap.h"
6465
#include "llvm/ADT/MapVector.h"
6566
#include "llvm/ADT/PostOrderIterator.h"
@@ -324,6 +325,10 @@ class Vectorizer {
324325
Instruction *ChainElem, Instruction *ChainBegin,
325326
const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
326327

328+
/// Merge equivalence classes if casts could be inserted in one to match
329+
/// the total bitwidth of the instructions.
330+
void insertCastsToMergeClasses(EquivalenceClassMap &EQClasses);
331+
327332
/// Merges the equivalence classes if they have underlying objects that differ
328333
/// by one level of indirection (i.e., one is a getelementptr and the other is
329334
/// the base pointer in that getelementptr).
@@ -1310,6 +1315,135 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects(
13101315
return std::nullopt;
13111316
}
13121317

1318+
void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1319+
if (EQClasses.size() < 2)
1320+
return;
1321+
1322+
auto CopyMetaDataFromTo = [&](Instruction *Src, Instruction *Dst) {
1323+
SmallVector<std::pair<unsigned, MDNode *>, 4> MD;
1324+
Src->getAllMetadata(MD);
1325+
for (const auto [ID, Node] : MD) {
1326+
Dst->setMetadata(ID, Node);
1327+
}
1328+
};
1329+
1330+
// For each class, determine if all instructions are of type int, FP or ptr.
1331+
// This information will help us determine the type instructions should be
1332+
// casted into.
1333+
MapVector<EqClassKey, Bitset<3>> ClassAllTy;
1334+
for (const auto &C : EQClasses) {
1335+
auto CommonTypeKind = [](Instruction *I) {
1336+
if (I->getType()->isIntOrIntVectorTy())
1337+
return 0;
1338+
if (I->getType()->isFPOrFPVectorTy())
1339+
return 1;
1340+
if (I->getType()->isPtrOrPtrVectorTy())
1341+
return 2;
1342+
return -1; // Invalid type kind
1343+
};
1344+
1345+
int FirstTypeKind = CommonTypeKind(EQClasses[C.first][0]);
1346+
if (FirstTypeKind != -1 && all_of(EQClasses[C.first], [&](Instruction *I) {
1347+
return CommonTypeKind(I) == FirstTypeKind;
1348+
})) {
1349+
ClassAllTy[C.first].set(FirstTypeKind);
1350+
}
1351+
}
1352+
1353+
// Loop over all equivalence classes and try to merge them. Keep track of
1354+
// classes that are merged into others.
1355+
DenseSet<EqClassKey> ClassesToErase;
1356+
for (auto EC1 : EQClasses) {
1357+
for (auto EC2 : EQClasses) {
1358+
// Skip if EC2 was already merged before, EC1 follows EC2 in the
1359+
// collection or EC1 is the same as EC2.
1360+
if (ClassesToErase.contains(EC2.first) || EC1 <= EC2 ||
1361+
EC1.first == EC2.first)
1362+
continue;
1363+
1364+
auto [Ptr1, AS1, TySize1, IsLoad1] = EC1.first;
1365+
auto [Ptr2, AS2, TySize2, IsLoad2] = EC2.first;
1366+
1367+
// Attempt to merge EC2 into EC1. Skip if the pointers, address spaces or
1368+
// whether the leader instruction is a load/store are different. Also skip
1369+
// if the scalar bitwidth of the first equivalence class is smaller than
1370+
// the second one to avoid reconsidering the same equivalence class pair.
1371+
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
1372+
continue;
1373+
1374+
// An All-FP class should only be merged into another All-FP class.
1375+
if ((ClassAllTy[EC1.first].test(1) && !ClassAllTy[EC2.first].test(1)) ||
1376+
(!ClassAllTy[EC1.first].test(2) && ClassAllTy[EC2.first].test(2)))
1377+
continue;
1378+
1379+
// Ensure all instructions in EC2 can be bitcasted into NewTy.
1380+
/// TODO: NewTyBits is needed as stuctured binded variables cannot be
1381+
/// captured by a lambda until C++20.
1382+
auto NewTyBits = std::get<2>(EC1.first);
1383+
if (any_of(EC2.second, [&](Instruction *I) {
1384+
return DL.getTypeSizeInBits(getLoadStoreType(I)) != NewTyBits;
1385+
}))
1386+
continue;
1387+
1388+
// Create a new type for the equivalence class.
1389+
auto &Ctx = EC2.second[0]->getContext();
1390+
Type *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits);
1391+
if (ClassAllTy[EC1.first].test(1) && ClassAllTy[EC2.first].test(1)) {
1392+
if (NewTyBits == 16)
1393+
NewTy = Type::getHalfTy(Ctx);
1394+
else if (NewTyBits == 32)
1395+
NewTy = Type::getFloatTy(Ctx);
1396+
else if (NewTyBits == 64)
1397+
NewTy = Type::getDoubleTy(Ctx);
1398+
} else if (ClassAllTy[EC1.first].test(2) &&
1399+
ClassAllTy[EC2.first].test(2)) {
1400+
NewTy = PointerType::get(Ctx, AS2);
1401+
}
1402+
1403+
for (auto *Inst : EC2.second) {
1404+
Value *Ptr = getLoadStorePointerOperand(Inst);
1405+
Type *OrigTy = Inst->getType();
1406+
if (OrigTy == NewTy)
1407+
continue;
1408+
if (auto *LI = dyn_cast<LoadInst>(Inst)) {
1409+
Builder.SetInsertPoint(LI->getIterator());
1410+
auto *NewLoad = Builder.CreateLoad(NewTy, Ptr);
1411+
auto *Cast = Builder.CreateBitOrPointerCast(
1412+
NewLoad, OrigTy, NewLoad->getName() + ".cast");
1413+
LI->replaceAllUsesWith(Cast);
1414+
CopyMetaDataFromTo(LI, NewLoad);
1415+
LI->eraseFromParent();
1416+
EQClasses[EC1.first].emplace_back(NewLoad);
1417+
} else {
1418+
auto *SI = cast<StoreInst>(Inst);
1419+
Builder.SetInsertPoint(SI->getIterator());
1420+
auto *Cast = Builder.CreateBitOrPointerCast(
1421+
SI->getValueOperand(), NewTy,
1422+
SI->getValueOperand()->getName() + ".cast");
1423+
auto *NewStore = Builder.CreateStore(
1424+
Cast, getLoadStorePointerOperand(SI), SI->isVolatile());
1425+
CopyMetaDataFromTo(SI, NewStore);
1426+
SI->eraseFromParent();
1427+
EQClasses[EC1.first].emplace_back(NewStore);
1428+
}
1429+
}
1430+
1431+
// Sort the instructions in the equivalence class by their order in the
1432+
// basic block. This is important to ensure that the instructions are
1433+
// vectorized in the correct order.
1434+
std::sort(EQClasses[EC1.first].begin(), EQClasses[EC1.first].end(),
1435+
[](const Instruction *A, const Instruction *B) {
1436+
return A && B && A->comesBefore(B);
1437+
});
1438+
ClassesToErase.insert(EC2.first);
1439+
}
1440+
}
1441+
1442+
// Erase the equivalence classes that were merged into others.
1443+
for (auto Key : ClassesToErase)
1444+
EQClasses.erase(Key);
1445+
}
1446+
13131447
void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
13141448
if (EQClasses.size() < 2) // There is nothing to merge.
13151449
return;
@@ -1495,7 +1629,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
14951629
/*IsLoad=*/LI != nullptr}]
14961630
.emplace_back(&I);
14971631
}
1498-
1632+
insertCastsToMergeClasses(Ret);
14991633
mergeEquivalenceClasses(Ret);
15001634
return Ret;
15011635
}

0 commit comments

Comments
 (0)