Skip to content

Commit a1ecd07

Browse files
committed
[LSV] Insert casts to vectorize mismatched types
After collecting equivalence classes, loop over each distinct pair of them and check if they could be merged into one. Consider classes A and B such that their leaders differ only by their scalar bitwidths. We do not yet merge them otherwise. Let N be the scalar bitwidth of the leader instruction in A. Iterate over all instructions in B and ensure their total bitwidths match the total bitwidth of the leader instruction of A. Finally, cast each instruction in B with a mismatched type to a pointer, integer or floating point type. Resolve issue #97715 Change-Id: Ib64fd98de5c908262947648ad14dc53b61814642
1 parent 316a76b commit a1ecd07

36 files changed

+2264
-2078
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
6161
#include "llvm/ADT/APInt.h"
6262
#include "llvm/ADT/ArrayRef.h"
63+
#include "llvm/ADT/Bitset.h"
6364
#include "llvm/ADT/DenseMap.h"
6465
#include "llvm/ADT/MapVector.h"
6566
#include "llvm/ADT/PostOrderIterator.h"
@@ -324,6 +325,10 @@ class Vectorizer {
324325
Instruction *ChainElem, Instruction *ChainBegin,
325326
const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
326327

328+
/// Merge equivalence classes if casts could be inserted in one to match
329+
/// the total bitwidth of the instructions.
330+
void insertCastsToMergeClasses(EquivalenceClassMap &EQClasses);
331+
327332
/// Merges the equivalence classes if they have underlying objects that differ
328333
/// by one level of indirection (i.e., one is a getelementptr and the other is
329334
/// the base pointer in that getelementptr).
@@ -1308,6 +1313,135 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects(
13081313
return std::nullopt;
13091314
}
13101315

1316+
void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1317+
if (EQClasses.size() < 2)
1318+
return;
1319+
1320+
auto CopyMetaDataFromTo = [&](Instruction *Src, Instruction *Dst) {
1321+
SmallVector<std::pair<unsigned, MDNode *>, 4> MD;
1322+
Src->getAllMetadata(MD);
1323+
for (const auto [ID, Node] : MD) {
1324+
Dst->setMetadata(ID, Node);
1325+
}
1326+
};
1327+
1328+
// For each class, determine if all instructions are of type int, FP or ptr.
1329+
// This information will help us determine the type instructions should be
1330+
// casted into.
1331+
MapVector<EqClassKey, Bitset<3>> ClassAllTy;
1332+
for (const auto &C : EQClasses) {
1333+
auto CommonTypeKind = [](Instruction *I) {
1334+
if (I->getType()->isIntOrIntVectorTy())
1335+
return 0;
1336+
if (I->getType()->isFPOrFPVectorTy())
1337+
return 1;
1338+
if (I->getType()->isPtrOrPtrVectorTy())
1339+
return 2;
1340+
return -1; // Invalid type kind
1341+
};
1342+
1343+
int FirstTypeKind = CommonTypeKind(EQClasses[C.first][0]);
1344+
if (FirstTypeKind != -1 && all_of(EQClasses[C.first], [&](Instruction *I) {
1345+
return CommonTypeKind(I) == FirstTypeKind;
1346+
})) {
1347+
ClassAllTy[C.first].set(FirstTypeKind);
1348+
}
1349+
}
1350+
1351+
// Loop over all equivalence classes and try to merge them. Keep track of
1352+
// classes that are merged into others.
1353+
DenseSet<EqClassKey> ClassesToErase;
1354+
for (auto EC1 : EQClasses) {
1355+
for (auto EC2 : EQClasses) {
1356+
// Skip if EC2 was already merged before, EC1 follows EC2 in the
1357+
// collection or EC1 is the same as EC2.
1358+
if (ClassesToErase.contains(EC2.first) || EC1 <= EC2 ||
1359+
EC1.first == EC2.first)
1360+
continue;
1361+
1362+
auto [Ptr1, AS1, TySize1, IsLoad1] = EC1.first;
1363+
auto [Ptr2, AS2, TySize2, IsLoad2] = EC2.first;
1364+
1365+
// Attempt to merge EC2 into EC1. Skip if the pointers, address spaces or
1366+
// whether the leader instruction is a load/store are different. Also skip
1367+
// if the scalar bitwidth of the first equivalence class is smaller than
1368+
// the second one to avoid reconsidering the same equivalence class pair.
1369+
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
1370+
continue;
1371+
1372+
// An All-FP class should only be merged into another All-FP class.
1373+
if ((ClassAllTy[EC1.first].test(1) && !ClassAllTy[EC2.first].test(1)) ||
1374+
(!ClassAllTy[EC1.first].test(2) && ClassAllTy[EC2.first].test(2)))
1375+
continue;
1376+
1377+
// Ensure all instructions in EC2 can be bitcasted into NewTy.
1378+
/// TODO: NewTyBits is needed as stuctured binded variables cannot be
1379+
/// captured by a lambda until C++20.
1380+
auto NewTyBits = std::get<2>(EC1.first);
1381+
if (any_of(EC2.second, [&](Instruction *I) {
1382+
return DL.getTypeSizeInBits(getLoadStoreType(I)) != NewTyBits;
1383+
}))
1384+
continue;
1385+
1386+
// Create a new type for the equivalence class.
1387+
auto &Ctx = EC2.second[0]->getContext();
1388+
Type *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits);
1389+
if (ClassAllTy[EC1.first].test(1) && ClassAllTy[EC2.first].test(1)) {
1390+
if (NewTyBits == 16)
1391+
NewTy = Type::getHalfTy(Ctx);
1392+
else if (NewTyBits == 32)
1393+
NewTy = Type::getFloatTy(Ctx);
1394+
else if (NewTyBits == 64)
1395+
NewTy = Type::getDoubleTy(Ctx);
1396+
} else if (ClassAllTy[EC1.first].test(2) &&
1397+
ClassAllTy[EC2.first].test(2)) {
1398+
NewTy = PointerType::get(Ctx, AS2);
1399+
}
1400+
1401+
for (auto *Inst : EC2.second) {
1402+
Value *Ptr = getLoadStorePointerOperand(Inst);
1403+
Type *OrigTy = Inst->getType();
1404+
if (OrigTy == NewTy)
1405+
continue;
1406+
if (auto *LI = dyn_cast<LoadInst>(Inst)) {
1407+
Builder.SetInsertPoint(LI->getIterator());
1408+
auto *NewLoad = Builder.CreateLoad(NewTy, Ptr);
1409+
auto *Cast = Builder.CreateBitOrPointerCast(
1410+
NewLoad, OrigTy, NewLoad->getName() + ".cast");
1411+
LI->replaceAllUsesWith(Cast);
1412+
CopyMetaDataFromTo(LI, NewLoad);
1413+
LI->eraseFromParent();
1414+
EQClasses[EC1.first].emplace_back(NewLoad);
1415+
} else {
1416+
auto *SI = cast<StoreInst>(Inst);
1417+
Builder.SetInsertPoint(SI->getIterator());
1418+
auto *Cast = Builder.CreateBitOrPointerCast(
1419+
SI->getValueOperand(), NewTy,
1420+
SI->getValueOperand()->getName() + ".cast");
1421+
auto *NewStore = Builder.CreateStore(
1422+
Cast, getLoadStorePointerOperand(SI), SI->isVolatile());
1423+
CopyMetaDataFromTo(SI, NewStore);
1424+
SI->eraseFromParent();
1425+
EQClasses[EC1.first].emplace_back(NewStore);
1426+
}
1427+
}
1428+
1429+
// Sort the instructions in the equivalence class by their order in the
1430+
// basic block. This is important to ensure that the instructions are
1431+
// vectorized in the correct order.
1432+
std::sort(EQClasses[EC1.first].begin(), EQClasses[EC1.first].end(),
1433+
[](const Instruction *A, const Instruction *B) {
1434+
return A && B && A->comesBefore(B);
1435+
});
1436+
ClassesToErase.insert(EC2.first);
1437+
}
1438+
}
1439+
1440+
// Erase the equivalence classes that were merged into others.
1441+
for (auto Key : ClassesToErase)
1442+
EQClasses.erase(Key);
1443+
}
1444+
13111445
void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
13121446
if (EQClasses.size() < 2) // There is nothing to merge.
13131447
return;
@@ -1493,7 +1627,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
14931627
/*IsLoad=*/LI != nullptr}]
14941628
.emplace_back(&I);
14951629
}
1496-
1630+
insertCastsToMergeClasses(Ret);
14971631
mergeEquivalenceClasses(Ret);
14981632
return Ret;
14991633
}

0 commit comments

Comments
 (0)