[AMDGPU][True16][CodeGen] uaddsat/usubsat true16 selection in gisel #128233

Merged: 2 commits, Feb 25, 2025
201 changes: 107 additions & 94 deletions llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -223,8 +223,9 @@ static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
};
}

static bool isRegisterSize(unsigned Size) {
return Size % 32 == 0 && Size <= MaxRegisterSize;
static bool isRegisterSize(const GCNSubtarget &ST, unsigned Size) {
return ((ST.useRealTrue16Insts() && Size == 16) || Size % 32 == 0) &&
Size <= MaxRegisterSize;
}

static bool isRegisterVectorElementType(LLT EltTy) {
@@ -240,8 +241,8 @@ static bool isRegisterVectorType(LLT Ty) {
}

// TODO: replace all uses of isRegisterType with isRegisterClassType
static bool isRegisterType(LLT Ty) {
if (!isRegisterSize(Ty.getSizeInBits()))
static bool isRegisterType(const GCNSubtarget &ST, LLT Ty) {
if (!isRegisterSize(ST, Ty.getSizeInBits()))
return false;

if (Ty.isVector())
@@ -252,19 +253,21 @@ static bool isRegisterType(LLT Ty) {

// Any combination of 32 or 64-bit elements up the maximum register size, and
// multiples of v2s16.
static LegalityPredicate isRegisterType(unsigned TypeIdx) {
return [=](const LegalityQuery &Query) {
return isRegisterType(Query.Types[TypeIdx]);
static LegalityPredicate isRegisterType(const GCNSubtarget &ST,
unsigned TypeIdx) {
return [=, &ST](const LegalityQuery &Query) {
return isRegisterType(ST, Query.Types[TypeIdx]);
};
}

// RegisterType that doesn't have a corresponding RegClass.
// TODO: Once `isRegisterType` is replaced with `isRegisterClassType` this
// should be removed.
static LegalityPredicate isIllegalRegisterType(unsigned TypeIdx) {
return [=](const LegalityQuery &Query) {
static LegalityPredicate isIllegalRegisterType(const GCNSubtarget &ST,
unsigned TypeIdx) {
return [=, &ST](const LegalityQuery &Query) {
LLT Ty = Query.Types[TypeIdx];
return isRegisterType(Ty) &&
return isRegisterType(ST, Ty) &&
!SIRegisterInfo::getSGPRClassForBitWidth(Ty.getSizeInBits());
};
}
@@ -348,17 +351,20 @@ static std::initializer_list<LLT> AllS64Vectors = {V2S64, V3S64, V4S64, V5S64,
V6S64, V7S64, V8S64, V16S64};

// Checks whether a type is in the list of legal register types.
static bool isRegisterClassType(LLT Ty) {
static bool isRegisterClassType(const GCNSubtarget &ST, LLT Ty) {
if (Ty.isPointerOrPointerVector())
Ty = Ty.changeElementType(LLT::scalar(Ty.getScalarSizeInBits()));

return is_contained(AllS32Vectors, Ty) || is_contained(AllS64Vectors, Ty) ||
is_contained(AllScalarTypes, Ty) || is_contained(AllS16Vectors, Ty);
is_contained(AllScalarTypes, Ty) ||
(ST.useRealTrue16Insts() && Ty == S16) ||
is_contained(AllS16Vectors, Ty);
}

static LegalityPredicate isRegisterClassType(unsigned TypeIdx) {
return [TypeIdx](const LegalityQuery &Query) {
return isRegisterClassType(Query.Types[TypeIdx]);
static LegalityPredicate isRegisterClassType(const GCNSubtarget &ST,
unsigned TypeIdx) {
return [&ST, TypeIdx](const LegalityQuery &Query) {
return isRegisterClassType(ST, Query.Types[TypeIdx]);
};
}

@@ -510,7 +516,7 @@ static bool loadStoreBitcastWorkaround(const LLT Ty) {

static bool isLoadStoreLegal(const GCNSubtarget &ST, const LegalityQuery &Query) {
const LLT Ty = Query.Types[0];
return isRegisterType(Ty) && isLoadStoreSizeLegal(ST, Query) &&
return isRegisterType(ST, Ty) && isLoadStoreSizeLegal(ST, Query) &&
!hasBufferRsrcWorkaround(Ty) && !loadStoreBitcastWorkaround(Ty);
}

@@ -523,12 +529,12 @@ static bool shouldBitcastLoadStoreType(const GCNSubtarget &ST, const LLT Ty,
if (Size != MemSizeInBits)
return Size <= 32 && Ty.isVector();

if (loadStoreBitcastWorkaround(Ty) && isRegisterType(Ty))
if (loadStoreBitcastWorkaround(Ty) && isRegisterType(ST, Ty))
return true;

// Don't try to handle bitcasting vector ext loads for now.
return Ty.isVector() && (!MemTy.isVector() || MemTy == Ty) &&
(Size <= 32 || isRegisterSize(Size)) &&
(Size <= 32 || isRegisterSize(ST, Size)) &&
!isRegisterVectorElementType(Ty.getElementType());
}

@@ -875,7 +881,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,

getActionDefinitionsBuilder(G_BITCAST)
// Don't worry about the size constraint.
.legalIf(all(isRegisterClassType(0), isRegisterClassType(1)))
.legalIf(all(isRegisterClassType(ST, 0), isRegisterClassType(ST, 1)))
.lower();

getActionDefinitionsBuilder(G_CONSTANT)
@@ -890,7 +896,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
.clampScalar(0, S16, S64);

getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
.legalIf(isRegisterClassType(0))
.legalIf(isRegisterClassType(ST, 0))
// s1 and s16 are special cases because they have legal operations on
// them, but don't really occupy registers in the normal way.
.legalFor({S1, S16})
@@ -1779,7 +1785,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
unsigned IdxTypeIdx = 2;

getActionDefinitionsBuilder(Op)
.customIf([=](const LegalityQuery &Query) {
.customIf([=](const LegalityQuery &Query) {
const LLT EltTy = Query.Types[EltTypeIdx];
const LLT VecTy = Query.Types[VecTypeIdx];
const LLT IdxTy = Query.Types[IdxTypeIdx];
@@ -1800,36 +1806,37 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
IdxTy.getSizeInBits() == 32 &&
isLegalVecType;
})
.bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx), scalarOrEltNarrowerThan(VecTypeIdx, 32)),
bitcastToVectorElement32(VecTypeIdx))

@broxigarchen (Contributor, Author) commented on Feb 21, 2025: some of the changes in this patch are pure clang-format changes.
//.bitcastIf(vectorSmallerThan(1, 32), bitcastToScalar(1))
.bitcastIf(
all(sizeIsMultipleOf32(VecTypeIdx), scalarOrEltWiderThan(VecTypeIdx, 64)),
[=](const LegalityQuery &Query) {
// For > 64-bit element types, try to turn this into a 64-bit
// element vector since we may be able to do better indexing
// if this is scalar. If not, fall back to 32.
const LLT EltTy = Query.Types[EltTypeIdx];
const LLT VecTy = Query.Types[VecTypeIdx];
const unsigned DstEltSize = EltTy.getSizeInBits();
const unsigned VecSize = VecTy.getSizeInBits();

const unsigned TargetEltSize = DstEltSize % 64 == 0 ? 64 : 32;
return std::pair(
VecTypeIdx,
LLT::fixed_vector(VecSize / TargetEltSize, TargetEltSize));
})
.clampScalar(EltTypeIdx, S32, S64)
.clampScalar(VecTypeIdx, S32, S64)
.clampScalar(IdxTypeIdx, S32, S32)
.clampMaxNumElements(VecTypeIdx, S32, 32)
// TODO: Clamp elements for 64-bit vectors?
.moreElementsIf(
isIllegalRegisterType(VecTypeIdx),
moreElementsToNextExistingRegClass(VecTypeIdx))
// It should only be necessary with variable indexes.
// As a last resort, lower to the stack
.lower();
.bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx),
scalarOrEltNarrowerThan(VecTypeIdx, 32)),
bitcastToVectorElement32(VecTypeIdx))
//.bitcastIf(vectorSmallerThan(1, 32), bitcastToScalar(1))
.bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx),
scalarOrEltWiderThan(VecTypeIdx, 64)),
[=](const LegalityQuery &Query) {
// For > 64-bit element types, try to turn this into a
// 64-bit element vector since we may be able to do better
// indexing if this is scalar. If not, fall back to 32.
const LLT EltTy = Query.Types[EltTypeIdx];
const LLT VecTy = Query.Types[VecTypeIdx];
const unsigned DstEltSize = EltTy.getSizeInBits();
const unsigned VecSize = VecTy.getSizeInBits();

const unsigned TargetEltSize =
DstEltSize % 64 == 0 ? 64 : 32;
return std::pair(VecTypeIdx,
LLT::fixed_vector(VecSize / TargetEltSize,
TargetEltSize));
})
.clampScalar(EltTypeIdx, S32, S64)
.clampScalar(VecTypeIdx, S32, S64)
.clampScalar(IdxTypeIdx, S32, S32)
.clampMaxNumElements(VecTypeIdx, S32, 32)
// TODO: Clamp elements for 64-bit vectors?
.moreElementsIf(isIllegalRegisterType(ST, VecTypeIdx),
moreElementsToNextExistingRegClass(VecTypeIdx))
// It should only be necessary with variable indexes.
// As a last resort, lower to the stack
.lower();
}

getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
@@ -1876,15 +1883,15 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,

}

auto &BuildVector = getActionDefinitionsBuilder(G_BUILD_VECTOR)
.legalForCartesianProduct(AllS32Vectors, {S32})
.legalForCartesianProduct(AllS64Vectors, {S64})
.clampNumElements(0, V16S32, V32S32)
.clampNumElements(0, V2S64, V16S64)
.fewerElementsIf(isWideVec16(0), changeTo(0, V2S16))
.moreElementsIf(
isIllegalRegisterType(0),
moreElementsToNextExistingRegClass(0));
auto &BuildVector =
getActionDefinitionsBuilder(G_BUILD_VECTOR)
.legalForCartesianProduct(AllS32Vectors, {S32})
.legalForCartesianProduct(AllS64Vectors, {S64})
.clampNumElements(0, V16S32, V32S32)
.clampNumElements(0, V2S64, V16S64)
.fewerElementsIf(isWideVec16(0), changeTo(0, V2S16))
.moreElementsIf(isIllegalRegisterType(ST, 0),
moreElementsToNextExistingRegClass(0));

if (ST.hasScalarPackInsts()) {
BuildVector
@@ -1904,14 +1911,14 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
.lower();
}

BuildVector.legalIf(isRegisterType(0));
BuildVector.legalIf(isRegisterType(ST, 0));

// FIXME: Clamp maximum size
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
.legalIf(all(isRegisterType(0), isRegisterType(1)))
.clampMaxNumElements(0, S32, 32)
.clampMaxNumElements(1, S16, 2) // TODO: Make 4?
.clampMaxNumElements(0, S16, 64);
.legalIf(all(isRegisterType(ST, 0), isRegisterType(ST, 1)))
.clampMaxNumElements(0, S32, 32)
.clampMaxNumElements(1, S16, 2) // TODO: Make 4?
.clampMaxNumElements(0, S16, 64);

getActionDefinitionsBuilder(G_SHUFFLE_VECTOR).lower();

@@ -1932,34 +1939,40 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
return false;
};

auto &Builder = getActionDefinitionsBuilder(Op)
.legalIf(all(isRegisterType(0), isRegisterType(1)))
.lowerFor({{S16, V2S16}})
.lowerIf([=](const LegalityQuery &Query) {
const LLT BigTy = Query.Types[BigTyIdx];
return BigTy.getSizeInBits() == 32;
})
// Try to widen to s16 first for small types.
// TODO: Only do this on targets with legal s16 shifts
.minScalarOrEltIf(scalarNarrowerThan(LitTyIdx, 16), LitTyIdx, S16)
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
.moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
.fewerElementsIf(all(typeIs(0, S16), vectorWiderThan(1, 32),
elementTypeIs(1, S16)),
changeTo(1, V2S16))
// Clamp the little scalar to s8-s256 and make it a power of 2. It's not
// worth considering the multiples of 64 since 2*192 and 2*384 are not
// valid.
.clampScalar(LitTyIdx, S32, S512)
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)
// Break up vectors with weird elements into scalars
.fewerElementsIf(
[=](const LegalityQuery &Query) { return notValidElt(Query, LitTyIdx); },
scalarize(0))
.fewerElementsIf(
[=](const LegalityQuery &Query) { return notValidElt(Query, BigTyIdx); },
scalarize(1))
.clampScalar(BigTyIdx, S32, MaxScalar);
auto &Builder =
getActionDefinitionsBuilder(Op)
.legalIf(all(isRegisterType(ST, 0), isRegisterType(ST, 1)))
.lowerFor({{S16, V2S16}})
.lowerIf([=](const LegalityQuery &Query) {
const LLT BigTy = Query.Types[BigTyIdx];
return BigTy.getSizeInBits() == 32;
})
// Try to widen to s16 first for small types.
// TODO: Only do this on targets with legal s16 shifts
.minScalarOrEltIf(scalarNarrowerThan(LitTyIdx, 16), LitTyIdx, S16)
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
.moreElementsIf(isSmallOddVector(BigTyIdx),
oneMoreElement(BigTyIdx))
.fewerElementsIf(all(typeIs(0, S16), vectorWiderThan(1, 32),
elementTypeIs(1, S16)),
changeTo(1, V2S16))
// Clamp the little scalar to s8-s256 and make it a power of 2. It's
// not worth considering the multiples of 64 since 2*192 and 2*384
// are not valid.
.clampScalar(LitTyIdx, S32, S512)
.widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)
// Break up vectors with weird elements into scalars
.fewerElementsIf(
[=](const LegalityQuery &Query) {
return notValidElt(Query, LitTyIdx);
},
scalarize(0))
.fewerElementsIf(
[=](const LegalityQuery &Query) {
return notValidElt(Query, BigTyIdx);
},
scalarize(1))
.clampScalar(BigTyIdx, S32, MaxScalar);

if (Op == G_MERGE_VALUES) {
Builder.widenScalarIf(
@@ -3146,7 +3159,7 @@ bool AMDGPULegalizerInfo::legalizeLoad(LegalizerHelper &Helper,
} else {
// Extract the subvector.

if (isRegisterType(ValTy)) {
if (isRegisterType(ST, ValTy)) {
// If this a case where G_EXTRACT is legal, use it.
// (e.g. <3 x s32> -> <4 x s32>)
WideLoad = B.buildLoadFromOffset(WideTy, PtrReg, *MMO, 0).getReg(0);
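For context on the predicate changes in this file: with real True16 instructions, a plain 16-bit scalar now qualifies as a register-sized (and register-class) type, so s16 values can be legal without first widening to 32 bits. Below is a minimal standalone sketch of the size check — illustrative only; a bool and a raw bit width stand in for GCNSubtarget and LLT, and MaxRegisterSize is assumed to be 1024 as in the surrounding file.

```cpp
#include <cstdio>

// Simplified model of the updated isRegisterSize() predicate; not the LLVM
// code itself. HasRealTrue16 stands in for ST.useRealTrue16Insts().
constexpr unsigned MaxRegisterSize = 1024; // assumed from AMDGPULegalizerInfo.cpp

static bool isRegisterSize(bool HasRealTrue16, unsigned SizeInBits) {
  return ((HasRealTrue16 && SizeInBits == 16) || SizeInBits % 32 == 0) &&
         SizeInBits <= MaxRegisterSize;
}

int main() {
  std::printf("s16 without true16: %d\n", isRegisterSize(false, 16)); // 0
  std::printf("s16 with true16:    %d\n", isRegisterSize(true, 16));  // 1
  std::printf("s32 either way:     %d\n", isRegisterSize(true, 32));  // 1
  std::printf("s48 either way:     %d\n", isRegisterSize(true, 48));  // 0
  return 0;
}
```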
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPURegisterBanks.td
@@ -11,7 +11,7 @@ def SGPRRegBank : RegisterBank<"SGPR",
>;

def VGPRRegBank : RegisterBank<"VGPR",
[VGPR_32, VReg_64, VReg_96, VReg_128, VReg_160, VReg_192, VReg_224, VReg_256, VReg_288, VReg_320, VReg_352, VReg_384, VReg_512, VReg_1024]
[VGPR_16_Lo128, VGPR_16, VGPR_32, VReg_64, VReg_96, VReg_128, VReg_160, VReg_192, VReg_224, VReg_256, VReg_288, VReg_320, VReg_352, VReg_384, VReg_512, VReg_1024]
>;

// It is helpful to distinguish conditions from ordinary SGPRs.
18 changes: 9 additions & 9 deletions llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -35,7 +35,7 @@ static cl::opt<bool> EnableSpillSGPRToVGPR(
cl::ReallyHidden,
cl::init(true));

std::array<std::vector<int16_t>, 16> SIRegisterInfo::RegSplitParts;
std::array<std::vector<int16_t>, 32> SIRegisterInfo::RegSplitParts;
std::array<std::array<uint16_t, 32>, 9> SIRegisterInfo::SubRegFromChannelTable;

// Map numbers of DWORDs to indexes in SubRegFromChannelTable.
@@ -351,9 +351,9 @@ SIRegisterInfo::SIRegisterInfo(const GCNSubtarget &ST)
static auto InitializeRegSplitPartsOnce = [this]() {
for (unsigned Idx = 1, E = getNumSubRegIndices() - 1; Idx < E; ++Idx) {
unsigned Size = getSubRegIdxSize(Idx);
if (Size & 31)
if (Size & 15)
continue;
std::vector<int16_t> &Vec = RegSplitParts[Size / 32 - 1];
std::vector<int16_t> &Vec = RegSplitParts[Size / 16 - 1];
unsigned Pos = getSubRegIdxOffset(Idx);
if (Pos % Size)
continue;
@@ -3554,14 +3554,14 @@ bool SIRegisterInfo::isUniformReg(const MachineRegisterInfo &MRI,
ArrayRef<int16_t> SIRegisterInfo::getRegSplitParts(const TargetRegisterClass *RC,
unsigned EltSize) const {
const unsigned RegBitWidth = AMDGPU::getRegBitWidth(*RC);
assert(RegBitWidth >= 32 && RegBitWidth <= 1024);
assert(RegBitWidth >= 32 && RegBitWidth <= 1024 && EltSize >= 2);

const unsigned RegDWORDs = RegBitWidth / 32;
const unsigned EltDWORDs = EltSize / 4;
assert(RegSplitParts.size() + 1 >= EltDWORDs);
const unsigned RegHalves = RegBitWidth / 16;
const unsigned EltHalves = EltSize / 2;
assert(RegSplitParts.size() + 1 >= EltHalves);

const std::vector<int16_t> &Parts = RegSplitParts[EltDWORDs - 1];
const unsigned NumParts = RegDWORDs / EltDWORDs;
const std::vector<int16_t> &Parts = RegSplitParts[EltHalves - 1];
const unsigned NumParts = RegHalves / EltHalves;

return ArrayRef(Parts.data(), NumParts);
}
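A worked example of the new getRegSplitParts() arithmetic may help: the table and the index math now operate in 16-bit halves instead of 32-bit DWORDs, which is why RegSplitParts grows from 16 to 32 entries and why a 2-byte EltSize (a single 16-bit subregister) becomes representable. The sketch below restates only the index computation under those assumptions, rather than calling the real SIRegisterInfo.

```cpp
#include <cassert>
#include <cstdio>

// Re-statement of the index math from getRegSplitParts() after this patch,
// for illustration only. EltSizeBytes is the subregister element size in
// bytes; the table index would be EltHalves - 1, hence 32 entries for
// registers up to 1024 bits split into 16-bit pieces.
static unsigned numSplitParts(unsigned RegBitWidth, unsigned EltSizeBytes) {
  assert(RegBitWidth >= 32 && RegBitWidth <= 1024 && EltSizeBytes >= 2);
  const unsigned RegHalves = RegBitWidth / 16; // register size in 16-bit units
  const unsigned EltHalves = EltSizeBytes / 2; // element size in 16-bit units
  return RegHalves / EltHalves;                // number of split parts returned
}

int main() {
  std::printf("%u\n", numSplitParts(64, 2));  // 64-bit reg, 16-bit elements -> 4
  std::printf("%u\n", numSplitParts(64, 4));  // 64-bit reg, 32-bit elements -> 2
  std::printf("%u\n", numSplitParts(128, 8)); // 128-bit reg, 64-bit elements -> 2
  return 0;
}
```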
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/SIRegisterInfo.h
@@ -37,11 +37,11 @@ class SIRegisterInfo final : public AMDGPUGenRegisterInfo {
BitVector RegPressureIgnoredUnits;

/// Sub reg indexes for getRegSplitParts.
/// First index represents subreg size from 1 to 16 DWORDs.
/// First index represents subreg size from 1 to 32 Half DWORDS.
/// The inner vector is sorted by bit offset.
/// Provided a register can be fully split with given subregs,
/// all elements of the inner vector combined give a full lane mask.
static std::array<std::vector<int16_t>, 16> RegSplitParts;
static std::array<std::vector<int16_t>, 32> RegSplitParts;

// Table representing sub reg of given width and offset.
// First index is subreg size: 32, 64, 96, 128, 160, 192, 224, 256, 512.
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
@@ -2483,6 +2483,8 @@ bool isSISrcInlinableOperand(const MCInstrDesc &Desc, unsigned OpNo) {
// (move from MC* level to Target* level). Return size in bits.
unsigned getRegBitWidth(unsigned RCID) {
switch (RCID) {
case AMDGPU::VGPR_16RegClassID:
case AMDGPU::VGPR_16_Lo128RegClassID:
case AMDGPU::SGPR_LO16RegClassID:
case AMDGPU::AGPR_LO16RegClassID:
return 16;
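The AMDGPUBaseInfo change simply routes the two new 16-bit VGPR register-class IDs into the existing 16-bit case of getRegBitWidth(). A toy mirror of that switch, shown only to make the mapping explicit — the enum values here are placeholders, not LLVM's generated class IDs:

```cpp
#include <cstdio>

// Placeholder IDs for illustration; the real values come from TableGen.
enum RegClassID {
  VGPR_16RegClassID,
  VGPR_16_Lo128RegClassID,
  SGPR_LO16RegClassID,
  VGPR_32RegClassID,
};

// Mirrors the shape of AMDGPU::getRegBitWidth() after this patch: the new
// VGPR 16-bit classes report 16 bits, like the existing *_LO16 classes.
static unsigned getRegBitWidth(RegClassID RCID) {
  switch (RCID) {
  case VGPR_16RegClassID:
  case VGPR_16_Lo128RegClassID:
  case SGPR_LO16RegClassID:
    return 16;
  case VGPR_32RegClassID:
    return 32;
  }
  return 0;
}

int main() {
  std::printf("%u\n", getRegBitWidth(VGPR_16RegClassID));       // 16
  std::printf("%u\n", getRegBitWidth(VGPR_16_Lo128RegClassID)); // 16
  return 0;
}
```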