Skip to content

Commit 3df5fa5

Browse files
committed
[LSV] Refine type definitions
1 parent a1ecd07 commit 3df5fa5

File tree

2 files changed

+38
-40
lines changed

2 files changed

+38
-40
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,11 @@ void reorder(Instruction *I) {
238238
}
239239

240240
class Vectorizer {
241+
242+
enum ClassTyDist {
243+
Int, Float, Ptr, Other
244+
};
245+
241246
Function &F;
242247
AliasAnalysis &AA;
243248
AssumptionCache ∾
@@ -274,6 +279,17 @@ class Vectorizer {
274279
bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
275280
ArrayRef<Instruction *> EqClass);
276281

282+
static int getTypeKind(Instruction *I) {
283+
unsigned ID = I->getType()->getTypeID();
284+
switch(ID) {
285+
case Type::IntegerTyID:
286+
case Type::FloatTyID:
287+
case Type::PointerTyID:
288+
return ID;
289+
};
290+
return -1;
291+
}
292+
277293
/// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
278294
/// where all instructions access a known, constant offset from the first
279295
/// instruction.
@@ -1325,26 +1341,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
13251341
}
13261342
};
13271343

1328-
// For each class, determine if all instructions are of type int, FP or ptr.
1329-
// This information will help us determine the type instructions should be
1330-
// casted into.
1331-
MapVector<EqClassKey, Bitset<3>> ClassAllTy;
1344+
// For each class, determine the most defined type. This information will
1345+
// help us determine the type instructions should be casted into.
1346+
MapVector<EqClassKey, unsigned> ClassToNewTyID;
13321347
for (const auto &C : EQClasses) {
1333-
auto CommonTypeKind = [](Instruction *I) {
1334-
if (I->getType()->isIntOrIntVectorTy())
1335-
return 0;
1336-
if (I->getType()->isFPOrFPVectorTy())
1337-
return 1;
1338-
if (I->getType()->isPtrOrPtrVectorTy())
1339-
return 2;
1340-
return -1; // Invalid type kind
1341-
};
1342-
1343-
int FirstTypeKind = CommonTypeKind(EQClasses[C.first][0]);
1348+
int FirstTypeKind = getTypeKind(EQClasses[C.first][0]);
13441349
if (FirstTypeKind != -1 && all_of(EQClasses[C.first], [&](Instruction *I) {
1345-
return CommonTypeKind(I) == FirstTypeKind;
1350+
return getTypeKind(I) == FirstTypeKind;
13461351
})) {
1347-
ClassAllTy[C.first].set(FirstTypeKind);
1352+
ClassToNewTyID[C.first] = FirstTypeKind;
13481353
}
13491354
}
13501355

@@ -1369,11 +1374,6 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
13691374
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
13701375
continue;
13711376

1372-
// An All-FP class should only be merged into another All-FP class.
1373-
if ((ClassAllTy[EC1.first].test(1) && !ClassAllTy[EC2.first].test(1)) ||
1374-
(!ClassAllTy[EC1.first].test(2) && ClassAllTy[EC2.first].test(2)))
1375-
continue;
1376-
13771377
// Ensure all instructions in EC2 can be bitcasted into NewTy.
13781378
/// TODO: NewTyBits is needed as stuctured binded variables cannot be
13791379
/// captured by a lambda until C++20.
@@ -1386,19 +1386,15 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
13861386
// Create a new type for the equivalence class.
13871387
auto &Ctx = EC2.second[0]->getContext();
13881388
Type *NewTy = Type::getIntNTy(EC2.second[0]->getContext(), NewTyBits);
1389-
if (ClassAllTy[EC1.first].test(1) && ClassAllTy[EC2.first].test(1)) {
1390-
if (NewTyBits == 16)
1391-
NewTy = Type::getHalfTy(Ctx);
1392-
else if (NewTyBits == 32)
1393-
NewTy = Type::getFloatTy(Ctx);
1394-
else if (NewTyBits == 64)
1395-
NewTy = Type::getDoubleTy(Ctx);
1396-
} else if (ClassAllTy[EC1.first].test(2) &&
1397-
ClassAllTy[EC2.first].test(2)) {
1389+
if (ClassToNewTyID[EC1.first] == Type::FloatTyID &&
1390+
ClassToNewTyID[EC2.first] == Type::FloatTyID) {
1391+
NewTy = Type::getFloatTy(Ctx);
1392+
} else if (ClassToNewTyID[EC1.first] == Type::PointerTyID &&
1393+
ClassToNewTyID[EC2.first] == Type::PointerTyID) {
13981394
NewTy = PointerType::get(Ctx, AS2);
13991395
}
14001396

1401-
for (auto *Inst : EC2.second) {
1397+
for (Instruction *Inst : EC2.second) {
14021398
Value *Ptr = getLoadStorePointerOperand(Inst);
14031399
Type *OrigTy = Inst->getType();
14041400
if (OrigTy == NewTy)

llvm/test/Transforms/LoadStoreVectorizer/AMDGPU/merge-vectors.ll

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,11 @@ define void @merge_fp_v2half_type(ptr addrspace(1) %ptr1, ptr addrspace(2) %ptr2
228228
; CHECK-OOB-RELAXED-LABEL: define void @merge_fp_v2half_type(
229229
; CHECK-OOB-RELAXED-SAME: ptr addrspace(1) [[PTR1:%.*]], ptr addrspace(2) [[PTR2:%.*]]) #[[ATTR1]] {
230230
; CHECK-OOB-RELAXED-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, ptr addrspace(1) [[PTR1]], i64 0
231-
; CHECK-OOB-RELAXED-NEXT: [[TMP1:%.*]] = load <2 x float>, ptr addrspace(1) [[GEP1]], align 4
232-
; CHECK-OOB-RELAXED-NEXT: [[LOAD11:%.*]] = extractelement <2 x float> [[TMP1]], i32 0
233-
; CHECK-OOB-RELAXED-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[TMP1]], i32 1
234-
; CHECK-OOB-RELAXED-NEXT: [[DOTCAST:%.*]] = bitcast float [[TMP2]] to <2 x half>
231+
; CHECK-OOB-RELAXED-NEXT: [[TMP1:%.*]] = load <2 x i32>, ptr addrspace(1) [[GEP1]], align 4
232+
; CHECK-OOB-RELAXED-NEXT: [[LOAD12:%.*]] = extractelement <2 x i32> [[TMP1]], i32 0
233+
; CHECK-OOB-RELAXED-NEXT: [[LOAD11:%.*]] = bitcast i32 [[LOAD12]] to float
234+
; CHECK-OOB-RELAXED-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP1]], i32 1
235+
; CHECK-OOB-RELAXED-NEXT: [[DOTCAST:%.*]] = bitcast i32 [[TMP6]] to <2 x half>
235236
; CHECK-OOB-RELAXED-NEXT: [[STORE_GEP1:%.*]] = getelementptr inbounds i32, ptr addrspace(2) [[PTR2]], i64 0
236237
; CHECK-OOB-RELAXED-NEXT: [[DOTCAST_CAST:%.*]] = bitcast <2 x half> [[DOTCAST]] to i32
237238
; CHECK-OOB-RELAXED-NEXT: [[TMP3:%.*]] = bitcast float [[LOAD11]] to i32
@@ -243,10 +244,11 @@ define void @merge_fp_v2half_type(ptr addrspace(1) %ptr1, ptr addrspace(2) %ptr2
243244
; CHECK-OOB-STRICT-LABEL: define void @merge_fp_v2half_type(
244245
; CHECK-OOB-STRICT-SAME: ptr addrspace(1) [[PTR1:%.*]], ptr addrspace(2) [[PTR2:%.*]]) {
245246
; CHECK-OOB-STRICT-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, ptr addrspace(1) [[PTR1]], i64 0
246-
; CHECK-OOB-STRICT-NEXT: [[TMP1:%.*]] = load <2 x float>, ptr addrspace(1) [[GEP1]], align 4
247-
; CHECK-OOB-STRICT-NEXT: [[LOAD11:%.*]] = extractelement <2 x float> [[TMP1]], i32 0
248-
; CHECK-OOB-STRICT-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[TMP1]], i32 1
249-
; CHECK-OOB-STRICT-NEXT: [[DOTCAST:%.*]] = bitcast float [[TMP2]] to <2 x half>
247+
; CHECK-OOB-STRICT-NEXT: [[TMP1:%.*]] = load <2 x i32>, ptr addrspace(1) [[GEP1]], align 4
248+
; CHECK-OOB-STRICT-NEXT: [[LOAD12:%.*]] = extractelement <2 x i32> [[TMP1]], i32 0
249+
; CHECK-OOB-STRICT-NEXT: [[LOAD11:%.*]] = bitcast i32 [[LOAD12]] to float
250+
; CHECK-OOB-STRICT-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP1]], i32 1
251+
; CHECK-OOB-STRICT-NEXT: [[DOTCAST:%.*]] = bitcast i32 [[TMP6]] to <2 x half>
250252
; CHECK-OOB-STRICT-NEXT: [[STORE_GEP1:%.*]] = getelementptr inbounds i32, ptr addrspace(2) [[PTR2]], i64 0
251253
; CHECK-OOB-STRICT-NEXT: [[DOTCAST_CAST:%.*]] = bitcast <2 x half> [[DOTCAST]] to i32
252254
; CHECK-OOB-STRICT-NEXT: [[TMP3:%.*]] = bitcast float [[LOAD11]] to i32

0 commit comments

Comments
 (0)