Skip to content

Commit 5a728c4

Browse files
committed
Add tests for structs, arrays, scalable vectors
1 parent 6eaffa8 commit 5a728c4

File tree

4 files changed

+1065
-413
lines changed

4 files changed

+1065
-413
lines changed

llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -732,8 +732,8 @@ Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
732732
ElemTy = VT->getElementType();
733733
}
734734
if (isa<PointerType, ScalableVectorType>(ElemTy))
735-
// Pointers are always big enough, and scalable vectors shouldn't crash the
736-
// pass.
735+
// Pointers are always big enough, and we'll let scalable vectors through to
736+
// fail in codegen.
737737
return T;
738738
unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
739739
if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
@@ -855,7 +855,10 @@ void LegalizeBufferContentTypesVisitor::getVecSlices(
855855

856856
Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
857857
const Twine &Name) {
858-
if (!isa<FixedVectorType>(Vec->getType()))
858+
auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType());
859+
if (!VecVT)
860+
return Vec;
861+
if (S.Length == VecVT->getNumElements() && S.Index == 0)
859862
return Vec;
860863
if (S.Length == 1)
861864
return IRB.CreateExtractElement(Vec, S.Index,
@@ -868,7 +871,10 @@ Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
868871
Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
869872
VecSlice S,
870873
const Twine &Name) {
871-
if (!isa<FixedVectorType>(Whole->getType()))
874+
auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType());
875+
if (!WholeVT)
876+
return Part;
877+
if (S.Length == WholeVT->getNumElements() && S.Index == 0)
872878
return Part;
873879
if (S.Length == 1) {
874880
return IRB.CreateInsertElement(Whole, Part, S.Index,
@@ -904,23 +910,24 @@ bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
904910
llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
905911
AggIdxs.push_back(I);
906912
Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
907-
AggByteOff + Offset.getKnownMinValue(), Result,
913+
AggByteOff + Offset.getFixedValue(), Result,
908914
Name + "." + Twine(I));
909915
AggIdxs.pop_back();
910916
}
911917
return Changed;
912918
}
913919
if (auto *AT = dyn_cast<ArrayType>(PartType)) {
914920
Type *ElemTy = AT->getElementType();
915-
TypeSize AllocSize = DL.getTypeAllocSizeInBits(ElemTy);
921+
TypeSize AllocSize = DL.getTypeAllocSize(ElemTy);
916922
if (!(ElemTy->isSingleValueType() &&
917-
DL.getTypeSizeInBits(ElemTy) == AllocSize && !ElemTy->isVectorTy())) {
923+
DL.getTypeSizeInBits(ElemTy) == 8 * AllocSize &&
924+
!ElemTy->isVectorTy())) {
918925
bool Changed = false;
919926
for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
920927
/*Inclusive=*/false)) {
921928
AggIdxs.push_back(I);
922929
Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
923-
AggByteOff + I * AllocSize.getKnownMinValue(),
930+
AggByteOff + I * AllocSize.getFixedValue(),
924931
Result, Name + Twine(I));
925932
AggIdxs.pop_back();
926933
}
@@ -1027,25 +1034,26 @@ std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
10271034
for (auto [I, ElemTy, Offset] :
10281035
llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
10291036
AggIdxs.push_back(I);
1030-
Changed |= std::get<0>(visitStoreImpl(
1031-
OrigSI, ElemTy, AggIdxs, AggByteOff + Offset.getKnownMinValue(),
1032-
Name + "." + Twine(I)));
1037+
Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs,
1038+
AggByteOff + Offset.getFixedValue(),
1039+
Name + "." + Twine(I)));
10331040
AggIdxs.pop_back();
10341041
}
10351042
return std::make_pair(Changed, /*ModifiedInPlace=*/false);
10361043
}
10371044
if (auto *AT = dyn_cast<ArrayType>(PartType)) {
10381045
Type *ElemTy = AT->getElementType();
1039-
TypeSize AllocSize = DL.getTypeAllocSizeInBits(ElemTy);
1046+
TypeSize AllocSize = DL.getTypeAllocSize(ElemTy);
10401047
if (!(ElemTy->isSingleValueType() &&
1041-
DL.getTypeSizeInBits(ElemTy) == AllocSize && !ElemTy->isVectorTy())) {
1048+
DL.getTypeSizeInBits(ElemTy) == 8 * AllocSize &&
1049+
!ElemTy->isVectorTy())) {
10421050
bool Changed = false;
10431051
for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
10441052
/*Inclusive=*/false)) {
10451053
AggIdxs.push_back(I);
10461054
Changed |= std::get<0>(visitStoreImpl(
1047-
OrigSI, ElemTy, AggIdxs,
1048-
AggByteOff + I * AllocSize.getKnownMinValue(), Name + Twine(I)));
1055+
OrigSI, ElemTy, AggIdxs, AggByteOff + I * AllocSize.getFixedValue(),
1056+
Name + Twine(I)));
10491057
AggIdxs.pop_back();
10501058
}
10511059
return std::make_pair(Changed, /*ModifiedInPlace=*/false);

0 commit comments

Comments
 (0)