Skip to content

Commit 932d9c1

Browse files
authored
[NVPTX] Generalize and extend upsizing when lowering 8/16-bit-element vector loads/stores (#119622)
This addresses the following issue I opened: #118851. This change generalizes the Type Legalization mechanism that currently handles `v8[i/f/bf]16` upsizing to include loads _and_ stores of `v8i8` + `v16i8`, allowing all of the mentioned vectors to be lowered to PTX as vectors of `b32`. This extension also allows us to remove the DAG combine that handled only `load v16i8`, thus centralizing all the upsizing logic in one place. Test changes include adding v8i8, v16i8, and v8i16 cases to load-store.ll, and updating the CHECKs for other tests to match the improved codegen.
1 parent 8a62104 commit 932d9c1

File tree

9 files changed

+3705
-2071
lines changed

9 files changed

+3705
-2071
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,17 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10011001
return true;
10021002
}
10031003

1004+
static bool isVectorElementTypeUpsized(EVT EltVT) {
1005+
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
1006+
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
1007+
// vectorized loads/stores with the actual element type for i8/i16 as that
1008+
// would require v8/v16 variants that do not exist.
1009+
// In order to load/store such vectors efficiently, in Type Legalization
1010+
// we split the vector into word-sized chunks (v2x16/v4i8). Now, we will
1011+
// lower to PTX as vectors of b32.
1012+
return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
1013+
}
1014+
10041015
bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10051016
MemSDNode *MemSD = cast<MemSDNode>(N);
10061017
EVT LoadedVT = MemSD->getMemoryVT();
@@ -1055,11 +1066,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10551066

10561067
EVT EltVT = N->getValueType(0);
10571068

1058-
// v8x16 is a special case. PTX doesn't have ld.v8.16
1059-
// instruction. Instead, we split the vector into v2x16 chunks and
1060-
// load them with ld.v4.b32.
1061-
if (Isv2x16VT(EltVT)) {
1062-
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
1069+
if (isVectorElementTypeUpsized(EltVT)) {
10631070
EltVT = MVT::i32;
10641071
FromType = NVPTX::PTXLdStInstCode::Untyped;
10651072
FromTypeWidth = 32;
@@ -1223,16 +1230,16 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12231230
if (EltVT.isVector()) {
12241231
NumElts = EltVT.getVectorNumElements();
12251232
EltVT = EltVT.getVectorElementType();
1226-
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1233+
// vectors of 8/16bits type are loaded/stored as multiples of v4i8/v2x16
1234+
// elements.
12271235
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
12281236
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
1229-
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
1230-
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1231-
EltVT = OrigType;
1232-
NumElts /= 2;
1233-
} else if (OrigType == MVT::v4i8) {
1237+
(EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
1238+
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
1239+
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
1240+
"NumElts must be divisible by the number of elts in subvectors");
12341241
EltVT = OrigType;
1235-
NumElts = 1;
1242+
NumElts /= OrigType.getVectorNumElements();
12361243
}
12371244
}
12381245

@@ -1739,11 +1746,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
17391746
return false;
17401747
}
17411748

1742-
// v8x16 is a special case. PTX doesn't have st.v8.x16
1743-
// instruction. Instead, we split the vector into v2x16 chunks and
1744-
// store them with st.v4.b32.
1745-
if (Isv2x16VT(EltVT)) {
1746-
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
1749+
if (isVectorElementTypeUpsized(EltVT)) {
17471750
EltVT = MVT::i32;
17481751
ToType = NVPTX::PTXLdStInstCode::Untyped;
17491752
ToTypeWidth = 32;

0 commit comments

Comments
 (0)