@@ -238,6 +238,11 @@ void reorder(Instruction *I) {
238
238
}
239
239
240
240
class Vectorizer {
241
+
242
+ enum ClassTyDist {
243
+ Int, Float, Ptr, Other
244
+ };
245
+
241
246
Function &F;
242
247
AliasAnalysis &AA;
243
248
AssumptionCache ∾
@@ -274,6 +279,17 @@ class Vectorizer {
274
279
bool runOnEquivalenceClass (const EqClassKey &EqClassKey,
275
280
ArrayRef<Instruction *> EqClass);
276
281
282
+ static int getTypeKind (Instruction *I) {
283
+ unsigned ID = I->getType ()->getTypeID ();
284
+ switch (ID) {
285
+ case Type::IntegerTyID:
286
+ case Type::FloatTyID:
287
+ case Type::PointerTyID:
288
+ return ID;
289
+ };
290
+ return -1 ;
291
+ }
292
+
277
293
// / Runs the vectorizer on one chain, i.e. a subset of an equivalence class
278
294
// / where all instructions access a known, constant offset from the first
279
295
// / instruction.
@@ -1327,26 +1343,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1327
1343
}
1328
1344
};
1329
1345
1330
- // For each class, determine if all instructions are of type int, FP or ptr.
1331
- // This information will help us determine the type instructions should be
1332
- // casted into.
1333
- MapVector<EqClassKey, Bitset<3 >> ClassAllTy;
1346
+ // For each class, determine the most defined type. This information will
1347
+ // help us determine the type instructions should be casted into.
1348
+ MapVector<EqClassKey, unsigned > ClassToNewTyID;
1334
1349
for (const auto &C : EQClasses) {
1335
- auto CommonTypeKind = [](Instruction *I) {
1336
- if (I->getType ()->isIntOrIntVectorTy ())
1337
- return 0 ;
1338
- if (I->getType ()->isFPOrFPVectorTy ())
1339
- return 1 ;
1340
- if (I->getType ()->isPtrOrPtrVectorTy ())
1341
- return 2 ;
1342
- return -1 ; // Invalid type kind
1343
- };
1344
-
1345
- int FirstTypeKind = CommonTypeKind (EQClasses[C.first ][0 ]);
1350
+ int FirstTypeKind = getTypeKind (EQClasses[C.first ][0 ]);
1346
1351
if (FirstTypeKind != -1 && all_of (EQClasses[C.first ], [&](Instruction *I) {
1347
- return CommonTypeKind (I) == FirstTypeKind;
1352
+ return getTypeKind (I) == FirstTypeKind;
1348
1353
})) {
1349
- ClassAllTy [C.first ]. set ( FirstTypeKind) ;
1354
+ ClassToNewTyID [C.first ] = FirstTypeKind;
1350
1355
}
1351
1356
}
1352
1357
@@ -1371,11 +1376,6 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1371
1376
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
1372
1377
continue ;
1373
1378
1374
- // An All-FP class should only be merged into another All-FP class.
1375
- if ((ClassAllTy[EC1.first ].test (1 ) && !ClassAllTy[EC2.first ].test (1 )) ||
1376
- (!ClassAllTy[EC1.first ].test (2 ) && ClassAllTy[EC2.first ].test (2 )))
1377
- continue ;
1378
-
1379
1379
// Ensure all instructions in EC2 can be bitcasted into NewTy.
1380
1380
// / TODO: NewTyBits is needed as stuctured binded variables cannot be
1381
1381
// / captured by a lambda until C++20.
@@ -1388,19 +1388,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1388
1388
// Create a new type for the equivalence class.
1389
1389
auto &Ctx = EC2.second [0 ]->getContext ();
1390
1390
Type *NewTy = Type::getIntNTy (EC2.second [0 ]->getContext (), NewTyBits);
1391
- if (ClassAllTy[EC1.first ].test (1 ) && ClassAllTy[EC2.first ].test (1 )) {
1392
- if (NewTyBits == 16 )
1393
- NewTy = Type::getHalfTy (Ctx);
1394
- else if (NewTyBits == 32 )
1395
- NewTy = Type::getFloatTy (Ctx);
1396
- else if (NewTyBits == 64 )
1397
- NewTy = Type::getDoubleTy (Ctx);
1398
- } else if (ClassAllTy[EC1.first ].test (2 ) &&
1399
- ClassAllTy[EC2.first ].test (2 )) {
1391
+ if (ClassToNewTyID[EC1.first ] == Type::FloatTyID &&
1392
+ ClassToNewTyID[EC2.first ] == Type::FloatTyID) {
1393
+ NewTy = Type::getFloatTy (Ctx);
1394
+ } else if (ClassToNewTyID[EC1.first ] == Type::PointerTyID &&
1395
+ ClassToNewTyID[EC2.first ] == Type::PointerTyID) {
1400
1396
NewTy = PointerType::get (Ctx, AS2);
1401
1397
}
1402
1398
1403
- for (auto *Inst : EC2.second ) {
1399
+ for (Instruction *Inst : EC2.second ) {
1404
1400
Value *Ptr = getLoadStorePointerOperand (Inst);
1405
1401
Type *OrigTy = Inst->getType ();
1406
1402
if (OrigTy == NewTy)
0 commit comments