|
60 | 60 | #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
|
61 | 61 | #include "llvm/ADT/APInt.h"
|
62 | 62 | #include "llvm/ADT/ArrayRef.h"
|
| 63 | +#include "llvm/ADT/Bitset.h" |
63 | 64 | #include "llvm/ADT/DenseMap.h"
|
64 | 65 | #include "llvm/ADT/MapVector.h"
|
65 | 66 | #include "llvm/ADT/PostOrderIterator.h"
|
@@ -324,6 +325,10 @@ class Vectorizer {
|
324 | 325 | Instruction *ChainElem, Instruction *ChainBegin,
|
325 | 326 | const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
|
326 | 327 |
|
| 328 | + /// Merge equivalence classes by inserting casts into one class so that the |
| 329 | + /// total bitwidth of its accesses matches the scalar bitwidth of the other. |
| 330 | + void insertCastsToMergeClasses(EquivalenceClassMap &EQClasses); |
| 331 | + |
327 | 332 | /// Merges the equivalence classes if they have underlying objects that differ
|
328 | 333 | /// by one level of indirection (i.e., one is a getelementptr and the other is
|
329 | 334 | /// the base pointer in that getelementptr).
|
@@ -1310,6 +1315,135 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects(
|
1310 | 1315 | return std::nullopt;
|
1311 | 1316 | }
|
1312 | 1317 |
|
| 1318 | +void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) { |
| 1319 | + if (EQClasses.size() < 2) |
| 1320 | + return; |
| 1321 | + |
| 1322 | + auto CopyMetaDataFromTo = [&](Instruction *Src, Instruction *Dst) { |
| 1323 | + SmallVector<std::pair<unsigned, MDNode *>, 4> MD; |
| 1324 | + Src->getAllMetadata(MD); |
| 1325 | + for (const auto [ID, Node] : MD) { |
| 1326 | + Dst->setMetadata(ID, Node); |
| 1327 | + } |
| 1328 | + }; |
| 1329 | + |
| 1330 | + // For each class, determine if all instructions are of type int, FP or ptr. |
| 1331 | + // This information will help us determine the type instructions should be |
| 1332 | + // casted into. |
| 1333 | + MapVector<EqClassKey, Bitset<3>> ClassAllTy; |
| 1334 | + for (const auto &C : EQClasses) { |
| 1335 | + auto CommonTypeKind = [](Instruction *I) { |
| 1336 | + if (I->getType()->isIntOrIntVectorTy()) |
| 1337 | + return 0; |
| 1338 | + if (I->getType()->isFPOrFPVectorTy()) |
| 1339 | + return 1; |
| 1340 | + if (I->getType()->isPtrOrPtrVectorTy()) |
| 1341 | + return 2; |
| 1342 | + return -1; // Invalid type kind |
| 1343 | + }; |
| 1344 | + |
| 1345 | + int FirstTypeKind = CommonTypeKind(EQClasses[C.first][0]); |
| 1346 | + if (FirstTypeKind != -1 && all_of(EQClasses[C.first], [&](Instruction *I) { |
| 1347 | + return CommonTypeKind(I) == FirstTypeKind; |
| 1348 | + })) { |
| 1349 | + ClassAllTy[C.first].set(FirstTypeKind); |
| 1350 | + } |
| 1351 | + } |
| 1352 | + |
| 1353 | + // Loop over all equivalence classes and try to merge them. Keep track of |
| 1354 | + // classes that are merged into others. |
| 1355 | + DenseSet<EqClassKey> ClassesToErase; |
| 1356 | + for (auto EC1 : EQClasses) { |
| 1357 | + for (auto EC2 : EQClasses) { |
| 1358 | + // Skip if EC2 was already merged before, EC1 follows EC2 in the |
| 1359 | + // collection or EC1 is the same as EC2. |
| 1360 | + if (ClassesToErase.contains(EC2.first) || EC1 <= EC2 || |
| 1361 | + EC1.first == EC2.first) |
| 1362 | + continue; |
| 1363 | + |
| 1364 | + auto [Ptr1, AS1, TySize1, IsLoad1] = EC1.first; |
| 1365 | + auto [Ptr2, AS2, TySize2, IsLoad2] = EC2.first; |
| 1366 | + |
| 1367 | + // Attempt to merge EC2 into EC1. Skip if the pointers, address spaces or |
| 1368 | + // whether the leader instruction is a load/store are different. Also skip |
| 1369 | + // if the scalar bitwidth of the first equivalence class is smaller than |
| 1370 | + // the second one to avoid reconsidering the same equivalence class pair. |
| 1371 | + if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2) |
| 1372 | + continue; |
| 1373 | + |
| 1374 | + // An All-FP class should only be merged into another All-FP class. |
| 1375 | + if ((ClassAllTy[EC1.first].test(1) && !ClassAllTy[EC2.first].test(1)) || |
| 1376 | + (!ClassAllTy[EC1.first].test(2) && ClassAllTy[EC2.first].test(2))) |
| 1377 | + continue; |
| 1378 | + |
| 1379 | + // Ensure all instructions in EC2 can be bitcasted into NewTy. |
| 1380 | + /// TODO: NewTyBits is needed as stuctured binded variables cannot be |
| 1381 | + /// captured by a lambda until C++20. |
| 1382 | + auto NewTyBits = std::get<2>(EC1.first); |
| 1383 | + if (any_of(EC2.second, [&](Instruction *I) { |
| 1384 | + return DL.getTypeSizeInBits(getLoadStoreType(I)) != NewTyBits; |
| 1385 | + })) |
| 1386 | + continue; |
| 1387 | + |
| 1388 | + // Create a new type for the equivalence class. |
| 1389 | + auto &Ctx = EC2.second[0]->getContext(); |
| 1390 | + Type *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits); |
| 1391 | + if (ClassAllTy[EC1.first].test(1) && ClassAllTy[EC2.first].test(1)) { |
| 1392 | + if (NewTyBits == 16) |
| 1393 | + NewTy = Type::getHalfTy(Ctx); |
| 1394 | + else if (NewTyBits == 32) |
| 1395 | + NewTy = Type::getFloatTy(Ctx); |
| 1396 | + else if (NewTyBits == 64) |
| 1397 | + NewTy = Type::getDoubleTy(Ctx); |
| 1398 | + } else if (ClassAllTy[EC1.first].test(2) && |
| 1399 | + ClassAllTy[EC2.first].test(2)) { |
| 1400 | + NewTy = PointerType::get(Ctx, AS2); |
| 1401 | + } |
| 1402 | + |
| 1403 | + for (auto *Inst : EC2.second) { |
| 1404 | + Value *Ptr = getLoadStorePointerOperand(Inst); |
| 1405 | + Type *OrigTy = Inst->getType(); |
| 1406 | + if (OrigTy == NewTy) |
| 1407 | + continue; |
| 1408 | + if (auto *LI = dyn_cast<LoadInst>(Inst)) { |
| 1409 | + Builder.SetInsertPoint(LI->getIterator()); |
| 1410 | + auto *NewLoad = Builder.CreateLoad(NewTy, Ptr); |
| 1411 | + auto *Cast = Builder.CreateBitOrPointerCast( |
| 1412 | + NewLoad, OrigTy, NewLoad->getName() + ".cast"); |
| 1413 | + LI->replaceAllUsesWith(Cast); |
| 1414 | + CopyMetaDataFromTo(LI, NewLoad); |
| 1415 | + LI->eraseFromParent(); |
| 1416 | + EQClasses[EC1.first].emplace_back(NewLoad); |
| 1417 | + } else { |
| 1418 | + auto *SI = cast<StoreInst>(Inst); |
| 1419 | + Builder.SetInsertPoint(SI->getIterator()); |
| 1420 | + auto *Cast = Builder.CreateBitOrPointerCast( |
| 1421 | + SI->getValueOperand(), NewTy, |
| 1422 | + SI->getValueOperand()->getName() + ".cast"); |
| 1423 | + auto *NewStore = Builder.CreateStore( |
| 1424 | + Cast, getLoadStorePointerOperand(SI), SI->isVolatile()); |
| 1425 | + CopyMetaDataFromTo(SI, NewStore); |
| 1426 | + SI->eraseFromParent(); |
| 1427 | + EQClasses[EC1.first].emplace_back(NewStore); |
| 1428 | + } |
| 1429 | + } |
| 1430 | + |
| 1431 | + // Sort the instructions in the equivalence class by their order in the |
| 1432 | + // basic block. This is important to ensure that the instructions are |
| 1433 | + // vectorized in the correct order. |
| 1434 | + std::sort(EQClasses[EC1.first].begin(), EQClasses[EC1.first].end(), |
| 1435 | + [](const Instruction *A, const Instruction *B) { |
| 1436 | + return A && B && A->comesBefore(B); |
| 1437 | + }); |
| 1438 | + ClassesToErase.insert(EC2.first); |
| 1439 | + } |
| 1440 | + } |
| 1441 | + |
| 1442 | + // Erase the equivalence classes that were merged into others. |
| 1443 | + for (auto Key : ClassesToErase) |
| 1444 | + EQClasses.erase(Key); |
| 1445 | +} |
| 1446 | + |
1313 | 1447 | void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
|
1314 | 1448 | if (EQClasses.size() < 2) // There is nothing to merge.
|
1315 | 1449 | return;
|
@@ -1495,7 +1629,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
|
1495 | 1629 | /*IsLoad=*/LI != nullptr}]
|
1496 | 1630 | .emplace_back(&I);
|
1497 | 1631 | }
|
1498 |
| - |
| 1632 | + insertCastsToMergeClasses(Ret); |
1499 | 1633 | mergeEquivalenceClasses(Ret);
|
1500 | 1634 | return Ret;
|
1501 | 1635 | }
|
|
0 commit comments