Skip to content

Commit 8f73cc4

Browse files
committed
Added VGPR_16 to GISEL register bank, support uaddsat/usubsat gisel
1 parent b11e1ba commit 8f73cc4

File tree

14 files changed

+982
-491
lines changed

14 files changed

+982
-491
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,9 +782,22 @@ bool AMDGPUInstructionSelector::selectG_BUILD_VECTOR(MachineInstr &MI) const {
782782
return true;
783783

784784
// TODO: This should probably be a combine somewhere
785-
// (build_vector $src0, undef) -> copy $src0
786785
MachineInstr *Src1Def = getDefIgnoringCopies(Src1, *MRI);
787786
if (Src1Def->getOpcode() == AMDGPU::G_IMPLICIT_DEF) {
787+
if (Subtarget->useRealTrue16Insts() && IsVector) {
788+
// (vecTy (DivergentBinFrag<build_vector> Ty:$src0, (Ty undef))),
789+
// -> (vecTy (INSERT_SUBREG (IMPLICIT_DEF), VGPR_16:$src0, lo16))
790+
Register Undef = MRI->createVirtualRegister(&AMDGPU::VGPR_32RegClass);
791+
BuildMI(*BB, &MI, DL, TII.get(AMDGPU::IMPLICIT_DEF), Undef);
792+
BuildMI(*BB, &MI, DL, TII.get(TargetOpcode::INSERT_SUBREG), Dst)
793+
.addReg(Undef)
794+
.addReg(Src0)
795+
.addImm(AMDGPU::lo16);
796+
MI.eraseFromParent();
797+
return RBI.constrainGenericRegister(Dst, AMDGPU::VGPR_32RegClass, *MRI) &&
798+
RBI.constrainGenericRegister(Src0, AMDGPU::VGPR_16RegClass, *MRI);
799+
}
800+
// (build_vector $src0, undef) -> copy $src0
788801
MI.setDesc(TII.get(AMDGPU::COPY));
789802
MI.removeOperand(2);
790803
const auto &RC =

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 107 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,9 @@ static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
223223
};
224224
}
225225

226-
static bool isRegisterSize(unsigned Size) {
227-
return Size % 32 == 0 && Size <= MaxRegisterSize;
226+
static bool isRegisterSize(const GCNSubtarget &ST, unsigned Size) {
227+
return ((ST.useRealTrue16Insts() && Size == 16) || Size % 32 == 0) &&
228+
Size <= MaxRegisterSize;
228229
}
229230

230231
static bool isRegisterVectorElementType(LLT EltTy) {
@@ -240,8 +241,8 @@ static bool isRegisterVectorType(LLT Ty) {
240241
}
241242

242243
// TODO: replace all uses of isRegisterType with isRegisterClassType
243-
static bool isRegisterType(LLT Ty) {
244-
if (!isRegisterSize(Ty.getSizeInBits()))
244+
static bool isRegisterType(const GCNSubtarget &ST, LLT Ty) {
245+
if (!isRegisterSize(ST, Ty.getSizeInBits()))
245246
return false;
246247

247248
if (Ty.isVector())
@@ -252,19 +253,21 @@ static bool isRegisterType(LLT Ty) {
252253

253254
// Any combination of 32 or 64-bit elements up the maximum register size, and
254255
// multiples of v2s16.
255-
static LegalityPredicate isRegisterType(unsigned TypeIdx) {
256-
return [=](const LegalityQuery &Query) {
257-
return isRegisterType(Query.Types[TypeIdx]);
256+
static LegalityPredicate isRegisterType(const GCNSubtarget &ST,
257+
unsigned TypeIdx) {
258+
return [=, &ST](const LegalityQuery &Query) {
259+
return isRegisterType(ST, Query.Types[TypeIdx]);
258260
};
259261
}
260262

261263
// RegisterType that doesn't have a corresponding RegClass.
262264
// TODO: Once `isRegisterType` is replaced with `isRegisterClassType` this
263265
// should be removed.
264-
static LegalityPredicate isIllegalRegisterType(unsigned TypeIdx) {
265-
return [=](const LegalityQuery &Query) {
266+
static LegalityPredicate isIllegalRegisterType(const GCNSubtarget &ST,
267+
unsigned TypeIdx) {
268+
return [=, &ST](const LegalityQuery &Query) {
266269
LLT Ty = Query.Types[TypeIdx];
267-
return isRegisterType(Ty) &&
270+
return isRegisterType(ST, Ty) &&
268271
!SIRegisterInfo::getSGPRClassForBitWidth(Ty.getSizeInBits());
269272
};
270273
}
@@ -348,17 +351,20 @@ static std::initializer_list<LLT> AllS64Vectors = {V2S64, V3S64, V4S64, V5S64,
348351
V6S64, V7S64, V8S64, V16S64};
349352

350353
// Checks whether a type is in the list of legal register types.
351-
static bool isRegisterClassType(LLT Ty) {
354+
static bool isRegisterClassType(const GCNSubtarget &ST, LLT Ty) {
352355
if (Ty.isPointerOrPointerVector())
353356
Ty = Ty.changeElementType(LLT::scalar(Ty.getScalarSizeInBits()));
354357

355358
return is_contained(AllS32Vectors, Ty) || is_contained(AllS64Vectors, Ty) ||
356-
is_contained(AllScalarTypes, Ty) || is_contained(AllS16Vectors, Ty);
359+
is_contained(AllScalarTypes, Ty) ||
360+
(ST.useRealTrue16Insts() && Ty == S16) ||
361+
is_contained(AllS16Vectors, Ty);
357362
}
358363

359-
static LegalityPredicate isRegisterClassType(unsigned TypeIdx) {
360-
return [TypeIdx](const LegalityQuery &Query) {
361-
return isRegisterClassType(Query.Types[TypeIdx]);
364+
static LegalityPredicate isRegisterClassType(const GCNSubtarget &ST,
365+
unsigned TypeIdx) {
366+
return [&ST, TypeIdx](const LegalityQuery &Query) {
367+
return isRegisterClassType(ST, Query.Types[TypeIdx]);
362368
};
363369
}
364370

@@ -510,7 +516,7 @@ static bool loadStoreBitcastWorkaround(const LLT Ty) {
510516

511517
static bool isLoadStoreLegal(const GCNSubtarget &ST, const LegalityQuery &Query) {
512518
const LLT Ty = Query.Types[0];
513-
return isRegisterType(Ty) && isLoadStoreSizeLegal(ST, Query) &&
519+
return isRegisterType(ST, Ty) && isLoadStoreSizeLegal(ST, Query) &&
514520
!hasBufferRsrcWorkaround(Ty) && !loadStoreBitcastWorkaround(Ty);
515521
}
516522

@@ -523,12 +529,12 @@ static bool shouldBitcastLoadStoreType(const GCNSubtarget &ST, const LLT Ty,
523529
if (Size != MemSizeInBits)
524530
return Size <= 32 && Ty.isVector();
525531

526-
if (loadStoreBitcastWorkaround(Ty) && isRegisterType(Ty))
532+
if (loadStoreBitcastWorkaround(Ty) && isRegisterType(ST, Ty))
527533
return true;
528534

529535
// Don't try to handle bitcasting vector ext loads for now.
530536
return Ty.isVector() && (!MemTy.isVector() || MemTy == Ty) &&
531-
(Size <= 32 || isRegisterSize(Size)) &&
537+
(Size <= 32 || isRegisterSize(ST, Size)) &&
532538
!isRegisterVectorElementType(Ty.getElementType());
533539
}
534540

@@ -875,7 +881,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
875881

876882
getActionDefinitionsBuilder(G_BITCAST)
877883
// Don't worry about the size constraint.
878-
.legalIf(all(isRegisterClassType(0), isRegisterClassType(1)))
884+
.legalIf(all(isRegisterClassType(ST, 0), isRegisterClassType(ST, 1)))
879885
.lower();
880886

881887
getActionDefinitionsBuilder(G_CONSTANT)
@@ -890,7 +896,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
890896
.clampScalar(0, S16, S64);
891897

892898
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
893-
.legalIf(isRegisterClassType(0))
899+
.legalIf(isRegisterClassType(ST, 0))
894900
// s1 and s16 are special cases because they have legal operations on
895901
// them, but don't really occupy registers in the normal way.
896902
.legalFor({S1, S16})
@@ -1779,7 +1785,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
17791785
unsigned IdxTypeIdx = 2;
17801786

17811787
getActionDefinitionsBuilder(Op)
1782-
.customIf([=](const LegalityQuery &Query) {
1788+
.customIf([=](const LegalityQuery &Query) {
17831789
const LLT EltTy = Query.Types[EltTypeIdx];
17841790
const LLT VecTy = Query.Types[VecTypeIdx];
17851791
const LLT IdxTy = Query.Types[IdxTypeIdx];
@@ -1800,36 +1806,37 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
18001806
IdxTy.getSizeInBits() == 32 &&
18011807
isLegalVecType;
18021808
})
1803-
.bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx), scalarOrEltNarrowerThan(VecTypeIdx, 32)),
1804-
bitcastToVectorElement32(VecTypeIdx))
1805-
//.bitcastIf(vectorSmallerThan(1, 32), bitcastToScalar(1))
1806-
.bitcastIf(
1807-
all(sizeIsMultipleOf32(VecTypeIdx), scalarOrEltWiderThan(VecTypeIdx, 64)),
1808-
[=](const LegalityQuery &Query) {
1809-
// For > 64-bit element types, try to turn this into a 64-bit
1810-
// element vector since we may be able to do better indexing
1811-
// if this is scalar. If not, fall back to 32.
1812-
const LLT EltTy = Query.Types[EltTypeIdx];
1813-
const LLT VecTy = Query.Types[VecTypeIdx];
1814-
const unsigned DstEltSize = EltTy.getSizeInBits();
1815-
const unsigned VecSize = VecTy.getSizeInBits();
1816-
1817-
const unsigned TargetEltSize = DstEltSize % 64 == 0 ? 64 : 32;
1818-
return std::pair(
1819-
VecTypeIdx,
1820-
LLT::fixed_vector(VecSize / TargetEltSize, TargetEltSize));
1821-
})
1822-
.clampScalar(EltTypeIdx, S32, S64)
1823-
.clampScalar(VecTypeIdx, S32, S64)
1824-
.clampScalar(IdxTypeIdx, S32, S32)
1825-
.clampMaxNumElements(VecTypeIdx, S32, 32)
1826-
// TODO: Clamp elements for 64-bit vectors?
1827-
.moreElementsIf(
1828-
isIllegalRegisterType(VecTypeIdx),
1829-
moreElementsToNextExistingRegClass(VecTypeIdx))
1830-
// It should only be necessary with variable indexes.
1831-
// As a last resort, lower to the stack
1832-
.lower();
1809+
.bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx),
1810+
scalarOrEltNarrowerThan(VecTypeIdx, 32)),
1811+
bitcastToVectorElement32(VecTypeIdx))
1812+
//.bitcastIf(vectorSmallerThan(1, 32), bitcastToScalar(1))
1813+
.bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx),
1814+
scalarOrEltWiderThan(VecTypeIdx, 64)),
1815+
[=](const LegalityQuery &Query) {
1816+
// For > 64-bit element types, try to turn this into a
1817+
// 64-bit element vector since we may be able to do better
1818+
// indexing if this is scalar. If not, fall back to 32.
1819+
const LLT EltTy = Query.Types[EltTypeIdx];
1820+
const LLT VecTy = Query.Types[VecTypeIdx];
1821+
const unsigned DstEltSize = EltTy.getSizeInBits();
1822+
const unsigned VecSize = VecTy.getSizeInBits();
1823+
1824+
const unsigned TargetEltSize =
1825+
DstEltSize % 64 == 0 ? 64 : 32;
1826+
return std::pair(VecTypeIdx,
1827+
LLT::fixed_vector(VecSize / TargetEltSize,
1828+
TargetEltSize));
1829+
})
1830+
.clampScalar(EltTypeIdx, S32, S64)
1831+
.clampScalar(VecTypeIdx, S32, S64)
1832+
.clampScalar(IdxTypeIdx, S32, S32)
1833+
.clampMaxNumElements(VecTypeIdx, S32, 32)
1834+
// TODO: Clamp elements for 64-bit vectors?
1835+
.moreElementsIf(isIllegalRegisterType(ST, VecTypeIdx),
1836+
moreElementsToNextExistingRegClass(VecTypeIdx))
1837+
// It should only be necessary with variable indexes.
1838+
// As a last resort, lower to the stack
1839+
.lower();
18331840
}
18341841

18351842
getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
@@ -1876,15 +1883,15 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
18761883

18771884
}
18781885

1879-
auto &BuildVector = getActionDefinitionsBuilder(G_BUILD_VECTOR)
1880-
.legalForCartesianProduct(AllS32Vectors, {S32})
1881-
.legalForCartesianProduct(AllS64Vectors, {S64})
1882-
.clampNumElements(0, V16S32, V32S32)
1883-
.clampNumElements(0, V2S64, V16S64)
1884-
.fewerElementsIf(isWideVec16(0), changeTo(0, V2S16))
1885-
.moreElementsIf(
1886-
isIllegalRegisterType(0),
1887-
moreElementsToNextExistingRegClass(0));
1886+
auto &BuildVector =
1887+
getActionDefinitionsBuilder(G_BUILD_VECTOR)
1888+
.legalForCartesianProduct(AllS32Vectors, {S32})
1889+
.legalForCartesianProduct(AllS64Vectors, {S64})
1890+
.clampNumElements(0, V16S32, V32S32)
1891+
.clampNumElements(0, V2S64, V16S64)
1892+
.fewerElementsIf(isWideVec16(0), changeTo(0, V2S16))
1893+
.moreElementsIf(isIllegalRegisterType(ST, 0),
1894+
moreElementsToNextExistingRegClass(0));
18881895

18891896
if (ST.hasScalarPackInsts()) {
18901897
BuildVector
@@ -1904,14 +1911,14 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
19041911
.lower();
19051912
}
19061913

1907-
BuildVector.legalIf(isRegisterType(0));
1914+
BuildVector.legalIf(isRegisterType(ST, 0));
19081915

19091916
// FIXME: Clamp maximum size
19101917
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
1911-
.legalIf(all(isRegisterType(0), isRegisterType(1)))
1912-
.clampMaxNumElements(0, S32, 32)
1913-
.clampMaxNumElements(1, S16, 2) // TODO: Make 4?
1914-
.clampMaxNumElements(0, S16, 64);
1918+
.legalIf(all(isRegisterType(ST, 0), isRegisterType(ST, 1)))
1919+
.clampMaxNumElements(0, S32, 32)
1920+
.clampMaxNumElements(1, S16, 2) // TODO: Make 4?
1921+
.clampMaxNumElements(0, S16, 64);
19151922

19161923
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR).lower();
19171924

@@ -1932,34 +1939,40 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
19321939
return false;
19331940
};
19341941

1935-
auto &Builder = getActionDefinitionsBuilder(Op)
1936-
.legalIf(all(isRegisterType(0), isRegisterType(1)))
1937-
.lowerFor({{S16, V2S16}})
1938-
.lowerIf([=](const LegalityQuery &Query) {
1939-
const LLT BigTy = Query.Types[BigTyIdx];
1940-
return BigTy.getSizeInBits() == 32;
1941-
})
1942-
// Try to widen to s16 first for small types.
1943-
// TODO: Only do this on targets with legal s16 shifts
1944-
.minScalarOrEltIf(scalarNarrowerThan(LitTyIdx, 16), LitTyIdx, S16)
1945-
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
1946-
.moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
1947-
.fewerElementsIf(all(typeIs(0, S16), vectorWiderThan(1, 32),
1948-
elementTypeIs(1, S16)),
1949-
changeTo(1, V2S16))
1950-
// Clamp the little scalar to s8-s256 and make it a power of 2. It's not
1951-
// worth considering the multiples of 64 since 2*192 and 2*384 are not
1952-
// valid.
1953-
.clampScalar(LitTyIdx, S32, S512)
1954-
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)
1955-
// Break up vectors with weird elements into scalars
1956-
.fewerElementsIf(
1957-
[=](const LegalityQuery &Query) { return notValidElt(Query, LitTyIdx); },
1958-
scalarize(0))
1959-
.fewerElementsIf(
1960-
[=](const LegalityQuery &Query) { return notValidElt(Query, BigTyIdx); },
1961-
scalarize(1))
1962-
.clampScalar(BigTyIdx, S32, MaxScalar);
1942+
auto &Builder =
1943+
getActionDefinitionsBuilder(Op)
1944+
.legalIf(all(isRegisterType(ST, 0), isRegisterType(ST, 1)))
1945+
.lowerFor({{S16, V2S16}})
1946+
.lowerIf([=](const LegalityQuery &Query) {
1947+
const LLT BigTy = Query.Types[BigTyIdx];
1948+
return BigTy.getSizeInBits() == 32;
1949+
})
1950+
// Try to widen to s16 first for small types.
1951+
// TODO: Only do this on targets with legal s16 shifts
1952+
.minScalarOrEltIf(scalarNarrowerThan(LitTyIdx, 16), LitTyIdx, S16)
1953+
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
1954+
.moreElementsIf(isSmallOddVector(BigTyIdx),
1955+
oneMoreElement(BigTyIdx))
1956+
.fewerElementsIf(all(typeIs(0, S16), vectorWiderThan(1, 32),
1957+
elementTypeIs(1, S16)),
1958+
changeTo(1, V2S16))
1959+
// Clamp the little scalar to s8-s256 and make it a power of 2. It's
1960+
// not worth considering the multiples of 64 since 2*192 and 2*384
1961+
// are not valid.
1962+
.clampScalar(LitTyIdx, S32, S512)
1963+
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)
1964+
// Break up vectors with weird elements into scalars
1965+
.fewerElementsIf(
1966+
[=](const LegalityQuery &Query) {
1967+
return notValidElt(Query, LitTyIdx);
1968+
},
1969+
scalarize(0))
1970+
.fewerElementsIf(
1971+
[=](const LegalityQuery &Query) {
1972+
return notValidElt(Query, BigTyIdx);
1973+
},
1974+
scalarize(1))
1975+
.clampScalar(BigTyIdx, S32, MaxScalar);
19631976

19641977
if (Op == G_MERGE_VALUES) {
19651978
Builder.widenScalarIf(
@@ -3146,7 +3159,7 @@ bool AMDGPULegalizerInfo::legalizeLoad(LegalizerHelper &Helper,
31463159
} else {
31473160
// Extract the subvector.
31483161

3149-
if (isRegisterType(ValTy)) {
3162+
if (isRegisterType(ST, ValTy)) {
31503163
// If this a case where G_EXTRACT is legal, use it.
31513164
// (e.g. <3 x s32> -> <4 x s32>)
31523165
WideLoad = B.buildLoadFromOffset(WideTy, PtrReg, *MMO, 0).getReg(0);

llvm/lib/Target/AMDGPU/AMDGPURegisterBanks.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def SGPRRegBank : RegisterBank<"SGPR",
1111
>;
1212

1313
def VGPRRegBank : RegisterBank<"VGPR",
14-
[VGPR_32, VReg_64, VReg_96, VReg_128, VReg_160, VReg_192, VReg_224, VReg_256, VReg_288, VReg_320, VReg_352, VReg_384, VReg_512, VReg_1024]
14+
[VGPR_16_Lo128, VGPR_16, VGPR_32, VReg_64, VReg_96, VReg_128, VReg_160, VReg_192, VReg_224, VReg_256, VReg_288, VReg_320, VReg_352, VReg_384, VReg_512, VReg_1024]
1515
>;
1616

1717
// It is helpful to distinguish conditions from ordinary SGPRs.

llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ static cl::opt<bool> EnableSpillSGPRToVGPR(
3535
cl::ReallyHidden,
3636
cl::init(true));
3737

38-
std::array<std::vector<int16_t>, 16> SIRegisterInfo::RegSplitParts;
38+
std::array<std::vector<int16_t>, 32> SIRegisterInfo::RegSplitParts;
3939
std::array<std::array<uint16_t, 32>, 9> SIRegisterInfo::SubRegFromChannelTable;
4040

4141
// Map numbers of DWORDs to indexes in SubRegFromChannelTable.
@@ -351,9 +351,9 @@ SIRegisterInfo::SIRegisterInfo(const GCNSubtarget &ST)
351351
static auto InitializeRegSplitPartsOnce = [this]() {
352352
for (unsigned Idx = 1, E = getNumSubRegIndices() - 1; Idx < E; ++Idx) {
353353
unsigned Size = getSubRegIdxSize(Idx);
354-
if (Size & 31)
354+
if (Size & 15)
355355
continue;
356-
std::vector<int16_t> &Vec = RegSplitParts[Size / 32 - 1];
356+
std::vector<int16_t> &Vec = RegSplitParts[Size / 16 - 1];
357357
unsigned Pos = getSubRegIdxOffset(Idx);
358358
if (Pos % Size)
359359
continue;
@@ -3554,14 +3554,14 @@ bool SIRegisterInfo::isUniformReg(const MachineRegisterInfo &MRI,
35543554
ArrayRef<int16_t> SIRegisterInfo::getRegSplitParts(const TargetRegisterClass *RC,
35553555
unsigned EltSize) const {
35563556
const unsigned RegBitWidth = AMDGPU::getRegBitWidth(*RC);
3557-
assert(RegBitWidth >= 32 && RegBitWidth <= 1024);
3557+
assert(RegBitWidth >= 32 && RegBitWidth <= 1024 && EltSize >= 2);
35583558

3559-
const unsigned RegDWORDs = RegBitWidth / 32;
3560-
const unsigned EltDWORDs = EltSize / 4;
3561-
assert(RegSplitParts.size() + 1 >= EltDWORDs);
3559+
const unsigned RegHalves = RegBitWidth / 16;
3560+
const unsigned EltHalves = EltSize / 2;
3561+
assert(RegSplitParts.size() + 1 >= EltHalves);
35623562

3563-
const std::vector<int16_t> &Parts = RegSplitParts[EltDWORDs - 1];
3564-
const unsigned NumParts = RegDWORDs / EltDWORDs;
3563+
const std::vector<int16_t> &Parts = RegSplitParts[EltHalves - 1];
3564+
const unsigned NumParts = RegHalves / EltHalves;
35653565

35663566
return ArrayRef(Parts.data(), NumParts);
35673567
}

0 commit comments

Comments
 (0)