@@ -238,6 +238,11 @@ void reorder(Instruction *I) {
238
238
}
239
239
240
240
class Vectorizer {
241
+
242
+ enum ClassTyDist {
243
+ Int, Float, Ptr, Other
244
+ };
245
+
241
246
Function &F;
242
247
AliasAnalysis &AA;
243
248
AssumptionCache ∾
@@ -274,6 +279,17 @@ class Vectorizer {
274
279
bool runOnEquivalenceClass (const EqClassKey &EqClassKey,
275
280
ArrayRef<Instruction *> EqClass);
276
281
282
+ static int getTypeKind (Instruction *I) {
283
+ unsigned ID = I->getType ()->getTypeID ();
284
+ switch (ID) {
285
+ case Type::IntegerTyID:
286
+ case Type::FloatTyID:
287
+ case Type::PointerTyID:
288
+ return ID;
289
+ };
290
+ return -1 ;
291
+ }
292
+
277
293
// / Runs the vectorizer on one chain, i.e. a subset of an equivalence class
278
294
// / where all instructions access a known, constant offset from the first
279
295
// / instruction.
@@ -1325,26 +1341,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1325
1341
}
1326
1342
};
1327
1343
1328
- // For each class, determine if all instructions are of type int, FP or ptr.
1329
- // This information will help us determine the type instructions should be
1330
- // casted into.
1331
- MapVector<EqClassKey, Bitset<3 >> ClassAllTy;
1344
+ // For each class, determine the most defined type. This information will
1345
+ // help us determine the type instructions should be casted into.
1346
+ MapVector<EqClassKey, unsigned > ClassToNewTyID;
1332
1347
for (const auto &C : EQClasses) {
1333
- auto CommonTypeKind = [](Instruction *I) {
1334
- if (I->getType ()->isIntOrIntVectorTy ())
1335
- return 0 ;
1336
- if (I->getType ()->isFPOrFPVectorTy ())
1337
- return 1 ;
1338
- if (I->getType ()->isPtrOrPtrVectorTy ())
1339
- return 2 ;
1340
- return -1 ; // Invalid type kind
1341
- };
1342
-
1343
- int FirstTypeKind = CommonTypeKind (EQClasses[C.first ][0 ]);
1348
+ int FirstTypeKind = getTypeKind (EQClasses[C.first ][0 ]);
1344
1349
if (FirstTypeKind != -1 && all_of (EQClasses[C.first ], [&](Instruction *I) {
1345
- return CommonTypeKind (I) == FirstTypeKind;
1350
+ return getTypeKind (I) == FirstTypeKind;
1346
1351
})) {
1347
- ClassAllTy [C.first ]. set ( FirstTypeKind) ;
1352
+ ClassToNewTyID [C.first ] = FirstTypeKind;
1348
1353
}
1349
1354
}
1350
1355
@@ -1369,11 +1374,6 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1369
1374
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
1370
1375
continue ;
1371
1376
1372
- // An All-FP class should only be merged into another All-FP class.
1373
- if ((ClassAllTy[EC1.first ].test (1 ) && !ClassAllTy[EC2.first ].test (1 )) ||
1374
- (!ClassAllTy[EC1.first ].test (2 ) && ClassAllTy[EC2.first ].test (2 )))
1375
- continue ;
1376
-
1377
1377
// Ensure all instructions in EC2 can be bitcasted into NewTy.
1378
1378
// / TODO: NewTyBits is needed as stuctured binded variables cannot be
1379
1379
// / captured by a lambda until C++20.
@@ -1386,19 +1386,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1386
1386
// Create a new type for the equivalence class.
1387
1387
auto &Ctx = EC2.second [0 ]->getContext ();
1388
1388
Type *NewTy = Type::getIntNTy (EC2.second [0 ]->getContext (), NewTyBits);
1389
- if (ClassAllTy[EC1.first ].test (1 ) && ClassAllTy[EC2.first ].test (1 )) {
1390
- if (NewTyBits == 16 )
1391
- NewTy = Type::getHalfTy (Ctx);
1392
- else if (NewTyBits == 32 )
1393
- NewTy = Type::getFloatTy (Ctx);
1394
- else if (NewTyBits == 64 )
1395
- NewTy = Type::getDoubleTy (Ctx);
1396
- } else if (ClassAllTy[EC1.first ].test (2 ) &&
1397
- ClassAllTy[EC2.first ].test (2 )) {
1389
+ if (ClassToNewTyID[EC1.first ] == Type::FloatTyID &&
1390
+ ClassToNewTyID[EC2.first ] == Type::FloatTyID) {
1391
+ NewTy = Type::getFloatTy (Ctx);
1392
+ } else if (ClassToNewTyID[EC1.first ] == Type::PointerTyID &&
1393
+ ClassToNewTyID[EC2.first ] == Type::PointerTyID) {
1398
1394
NewTy = PointerType::get (Ctx, AS2);
1399
1395
}
1400
1396
1401
- for (auto *Inst : EC2.second ) {
1397
+ for (Instruction *Inst : EC2.second ) {
1402
1398
Value *Ptr = getLoadStorePointerOperand (Inst);
1403
1399
Type *OrigTy = Inst->getType ();
1404
1400
if (OrigTy == NewTy)
0 commit comments