Skip to content

Commit c9c407b

Browse files
committed
[LSV] Refine type definitions
1 parent f32cd6d commit c9c407b

File tree

2 files changed

+38
-40
lines changed

2 files changed

+38
-40
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,11 @@ void reorder(Instruction *I) {
238238
}
239239

240240
class Vectorizer {
241+
242+
enum ClassTyDist {
243+
Int, Float, Ptr, Other
244+
};
245+
241246
Function &F;
242247
AliasAnalysis &AA;
243248
AssumptionCache &AC;
@@ -274,6 +279,17 @@ class Vectorizer {
274279
bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
275280
ArrayRef<Instruction *> EqClass);
276281

282+
static int getTypeKind(Instruction *I) {
283+
unsigned ID = I->getType()->getTypeID();
284+
switch(ID) {
285+
case Type::IntegerTyID:
286+
case Type::FloatTyID:
287+
case Type::PointerTyID:
288+
return ID;
289+
};
290+
return -1;
291+
}
292+
277293
/// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
278294
/// where all instructions access a known, constant offset from the first
279295
/// instruction.
@@ -1327,26 +1343,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
13271343
}
13281344
};
13291345

1330-
// For each class, determine if all instructions are of type int, FP or ptr.
1331-
// This information will help us determine the type instructions should be
1332-
// casted into.
1333-
MapVector<EqClassKey, Bitset<3>> ClassAllTy;
1346+
// For each class, determine the most defined type. This information will
1347+
// help us determine the type instructions should be casted into.
1348+
MapVector<EqClassKey, unsigned> ClassToNewTyID;
13341349
for (const auto &C : EQClasses) {
1335-
auto CommonTypeKind = [](Instruction *I) {
1336-
if (I->getType()->isIntOrIntVectorTy())
1337-
return 0;
1338-
if (I->getType()->isFPOrFPVectorTy())
1339-
return 1;
1340-
if (I->getType()->isPtrOrPtrVectorTy())
1341-
return 2;
1342-
return -1; // Invalid type kind
1343-
};
1344-
1345-
int FirstTypeKind = CommonTypeKind(EQClasses[C.first][0]);
1350+
int FirstTypeKind = getTypeKind(EQClasses[C.first][0]);
13461351
if (FirstTypeKind != -1 && all_of(EQClasses[C.first], [&](Instruction *I) {
1347-
return CommonTypeKind(I) == FirstTypeKind;
1352+
return getTypeKind(I) == FirstTypeKind;
13481353
})) {
1349-
ClassAllTy[C.first].set(FirstTypeKind);
1354+
ClassToNewTyID[C.first] = FirstTypeKind;
13501355
}
13511356
}
13521357

@@ -1371,11 +1376,6 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
13711376
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
13721377
continue;
13731378

1374-
// An All-FP class should only be merged into another All-FP class.
1375-
if ((ClassAllTy[EC1.first].test(1) && !ClassAllTy[EC2.first].test(1)) ||
1376-
(!ClassAllTy[EC1.first].test(2) && ClassAllTy[EC2.first].test(2)))
1377-
continue;
1378-
13791379
// Ensure all instructions in EC2 can be bitcasted into NewTy.
13801380
/// TODO: NewTyBits is needed as structured binding variables cannot be
13811381
/// captured by a lambda until C++20.
@@ -1388,19 +1388,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
13881388
// Create a new type for the equivalence class.
13891389
auto &Ctx = EC2.second[0]->getContext();
13901390
Type *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits);
1391-
if (ClassAllTy[EC1.first].test(1) && ClassAllTy[EC2.first].test(1)) {
1392-
if (NewTyBits == 16)
1393-
NewTy = Type::getHalfTy(Ctx);
1394-
else if (NewTyBits == 32)
1395-
NewTy = Type::getFloatTy(Ctx);
1396-
else if (NewTyBits == 64)
1397-
NewTy = Type::getDoubleTy(Ctx);
1398-
} else if (ClassAllTy[EC1.first].test(2) &&
1399-
ClassAllTy[EC2.first].test(2)) {
1391+
if (ClassToNewTyID[EC1.first] == Type::FloatTyID &&
1392+
ClassToNewTyID[EC2.first] == Type::FloatTyID) {
1393+
NewTy = Type::getFloatTy(Ctx);
1394+
} else if (ClassToNewTyID[EC1.first] == Type::PointerTyID &&
1395+
ClassToNewTyID[EC2.first] == Type::PointerTyID) {
14001396
NewTy = PointerType::get(Ctx, AS2);
14011397
}
14021398

1403-
for (auto *Inst : EC2.second) {
1399+
for (Instruction *Inst : EC2.second) {
14041400
Value *Ptr = getLoadStorePointerOperand(Inst);
14051401
Type *OrigTy = Inst->getType();
14061402
if (OrigTy == NewTy)

llvm/test/Transforms/LoadStoreVectorizer/AMDGPU/merge-vectors.ll

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,11 @@ define void @merge_fp_v2half_type(ptr addrspace(1) %ptr1, ptr addrspace(2) %ptr2
229229
; CHECK-OOB-RELAXED-LABEL: define void @merge_fp_v2half_type(
230230
; CHECK-OOB-RELAXED-SAME: ptr addrspace(1) [[PTR1:%.*]], ptr addrspace(2) [[PTR2:%.*]]) #[[ATTR1]] {
231231
; CHECK-OOB-RELAXED-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, ptr addrspace(1) [[PTR1]], i64 0
232-
; CHECK-OOB-RELAXED-NEXT: [[TMP1:%.*]] = load <2 x float>, ptr addrspace(1) [[GEP1]], align 4
233-
; CHECK-OOB-RELAXED-NEXT: [[LOAD11:%.*]] = extractelement <2 x float> [[TMP1]], i32 0
234-
; CHECK-OOB-RELAXED-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[TMP1]], i32 1
235-
; CHECK-OOB-RELAXED-NEXT: [[DOTCAST:%.*]] = bitcast float [[TMP2]] to <2 x half>
232+
; CHECK-OOB-RELAXED-NEXT: [[TMP1:%.*]] = load <2 x i32>, ptr addrspace(1) [[GEP1]], align 4
233+
; CHECK-OOB-RELAXED-NEXT: [[LOAD12:%.*]] = extractelement <2 x i32> [[TMP1]], i32 0
234+
; CHECK-OOB-RELAXED-NEXT: [[LOAD11:%.*]] = bitcast i32 [[LOAD12]] to float
235+
; CHECK-OOB-RELAXED-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP1]], i32 1
236+
; CHECK-OOB-RELAXED-NEXT: [[DOTCAST:%.*]] = bitcast i32 [[TMP6]] to <2 x half>
236237
; CHECK-OOB-RELAXED-NEXT: [[STORE_GEP1:%.*]] = getelementptr inbounds i32, ptr addrspace(2) [[PTR2]], i64 0
237238
; CHECK-OOB-RELAXED-NEXT: [[DOTCAST_CAST:%.*]] = bitcast <2 x half> [[DOTCAST]] to i32
238239
; CHECK-OOB-RELAXED-NEXT: [[TMP3:%.*]] = bitcast float [[LOAD11]] to i32
@@ -244,10 +245,11 @@ define void @merge_fp_v2half_type(ptr addrspace(1) %ptr1, ptr addrspace(2) %ptr2
244245
; CHECK-OOB-STRICT-LABEL: define void @merge_fp_v2half_type(
245246
; CHECK-OOB-STRICT-SAME: ptr addrspace(1) [[PTR1:%.*]], ptr addrspace(2) [[PTR2:%.*]]) {
246247
; CHECK-OOB-STRICT-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, ptr addrspace(1) [[PTR1]], i64 0
247-
; CHECK-OOB-STRICT-NEXT: [[TMP1:%.*]] = load <2 x float>, ptr addrspace(1) [[GEP1]], align 4
248-
; CHECK-OOB-STRICT-NEXT: [[LOAD11:%.*]] = extractelement <2 x float> [[TMP1]], i32 0
249-
; CHECK-OOB-STRICT-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[TMP1]], i32 1
250-
; CHECK-OOB-STRICT-NEXT: [[DOTCAST:%.*]] = bitcast float [[TMP2]] to <2 x half>
248+
; CHECK-OOB-STRICT-NEXT: [[TMP1:%.*]] = load <2 x i32>, ptr addrspace(1) [[GEP1]], align 4
249+
; CHECK-OOB-STRICT-NEXT: [[LOAD12:%.*]] = extractelement <2 x i32> [[TMP1]], i32 0
250+
; CHECK-OOB-STRICT-NEXT: [[LOAD11:%.*]] = bitcast i32 [[LOAD12]] to float
251+
; CHECK-OOB-STRICT-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP1]], i32 1
252+
; CHECK-OOB-STRICT-NEXT: [[DOTCAST:%.*]] = bitcast i32 [[TMP6]] to <2 x half>
251253
; CHECK-OOB-STRICT-NEXT: [[STORE_GEP1:%.*]] = getelementptr inbounds i32, ptr addrspace(2) [[PTR2]], i64 0
252254
; CHECK-OOB-STRICT-NEXT: [[DOTCAST_CAST:%.*]] = bitcast <2 x half> [[DOTCAST]] to i32
253255
; CHECK-OOB-STRICT-NEXT: [[TMP3:%.*]] = bitcast float [[LOAD11]] to i32

0 commit comments

Comments
 (0)