|
60 | 60 | #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
|
61 | 61 | #include "llvm/ADT/APInt.h"
|
62 | 62 | #include "llvm/ADT/ArrayRef.h"
|
| 63 | +#include "llvm/ADT/Bitset.h" |
63 | 64 | #include "llvm/ADT/DenseMap.h"
|
64 | 65 | #include "llvm/ADT/MapVector.h"
|
65 | 66 | #include "llvm/ADT/PostOrderIterator.h"
|
@@ -1318,6 +1319,28 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
|
1318 | 1319 | if (EQClasses.size() < 2)
|
1319 | 1320 | return;
|
1320 | 1321 |
|
| 1322 | + // For each class, determine if all instructions are of type int, FP or ptr. |
| 1323 | + // This information will help us determine the type instructions should be |
| 1324 | + // casted into. |
| 1325 | + MapVector<EqClassKey, Bitset<3>> ClassAllTy; |
| 1326 | + for (auto C : EQClasses) { |
| 1327 | + if (all_of(EQClasses[C.first], |
| 1328 | + [](Instruction *I) { |
| 1329 | + return I->getType()->isIntOrIntVectorTy(); |
| 1330 | + })) |
| 1331 | + ClassAllTy[C.first].set(0); |
| 1332 | + else if (all_of(EQClasses[C.first], |
| 1333 | + [](Instruction *I) { |
| 1334 | + return I->getType()->isFPOrFPVectorTy(); |
| 1335 | + })) |
| 1336 | + ClassAllTy[C.first].set(1); |
| 1337 | + else if (all_of(EQClasses[C.first], |
| 1338 | + [](Instruction *I) { |
| 1339 | + return I->getType()->isPtrOrPtrVectorTy(); |
| 1340 | + })) |
| 1341 | + ClassAllTy[C.first].set(2); |
| 1342 | + } |
| 1343 | + |
1321 | 1344 | // Loop over all equivalence classes and try to merge them. Keep track of
|
1322 | 1345 | // classes that are merged into others.
|
1323 | 1346 | DenseSet<EqClassKey> ClassesToErase;
|
@@ -1346,8 +1369,19 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
|
1346 | 1369 | continue;
|
1347 | 1370 |
|
1348 | 1371 | // Create a new type for the equivalence class.
|
1349 |
| - /// TODO: NewTy should be an FP type for an all-FP equivalence class. |
1350 |
| - auto *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits); |
| 1372 | + auto &Ctx = EC2.second[0]->getContext(); |
| 1373 | + Type *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits); |
| 1374 | + if (ClassAllTy[EC1.first].test(1) && ClassAllTy[EC2.first].test(1)) { |
| 1375 | + if (NewTyBits == 16) |
| 1376 | + NewTy = Type::getHalfTy(Ctx); |
| 1377 | + else if (NewTyBits == 32) |
| 1378 | + NewTy = Type::getFloatTy(Ctx); |
| 1379 | + else if (NewTyBits == 64) |
| 1380 | + NewTy = Type::getDoubleTy(Ctx); |
| 1381 | + } else if (ClassAllTy[EC1.first].test(2) && ClassAllTy[EC2.first].test(2)) { |
| 1382 | + NewTy = PointerType::get(Ctx, AS2); |
| 1383 | + } |
| 1384 | + |
1351 | 1385 | for (auto *Inst : EC2.second) {
|
1352 | 1386 | auto *Ptr = getLoadStorePointerOperand(Inst);
|
1353 | 1387 | auto *OrigTy = Inst->getType();
|
|
0 commit comments