[NVPTX] Generalize and extend upsizing when lowering 8/16-bit-element vector loads/stores #119622

Merged: 13 commits, Dec 17, 2024
37 changes: 20 additions & 17 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1001,6 +1001,17 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   return true;
 }
 
+static bool isVectorElementTypeUpsized(EVT EltVT) {
+  // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
+  // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
+  // vectorized loads/stores with the actual element type for i8/i16 as that
+  // would require v8/v16 variants that do not exist.
+  // In order to load/store such vectors efficiently, in Type Legalization
+  // we split the vector into word-sized chunks (v2x16/v4i8). Now, we will
+  // lower to PTX as vectors of b32.
+  return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
+}
+
 bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   MemSDNode *MemSD = cast<MemSDNode>(N);
   EVT LoadedVT = MemSD->getMemoryVT();
@@ -1055,11 +1066,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
 
   EVT EltVT = N->getValueType(0);
 
-  // v8x16 is a special case. PTX doesn't have ld.v8.16
-  // instruction. Instead, we split the vector into v2x16 chunks and
-  // load them with ld.v4.b32.
-  if (Isv2x16VT(EltVT)) {
-    assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
+  if (isVectorElementTypeUpsized(EltVT)) {
     EltVT = MVT::i32;
     FromType = NVPTX::PTXLdStInstCode::Untyped;
     FromTypeWidth = 32;
@@ -1223,16 +1230,16 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   if (EltVT.isVector()) {
     NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
-    // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
+    // vectors of 8/16bits type are loaded/stored as multiples of v4i8/v2x16
+    // elements.
     if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
         (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
-        (EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
-      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = OrigType;
-      NumElts /= 2;
-    } else if (OrigType == MVT::v4i8) {
+        (EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
+        (EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
+      assert(NumElts % OrigType.getVectorNumElements() == 0 &&
+             "NumElts must be divisible by the number of elts in subvectors");
       EltVT = OrigType;
-      NumElts = 1;
+      NumElts /= OrigType.getVectorNumElements();
     }
   }
 

@@ -1739,11 +1746,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
     return false;
   }
 
-  // v8x16 is a special case. PTX doesn't have st.v8.x16
-  // instruction. Instead, we split the vector into v2x16 chunks and
-  // store them with st.v4.b32.
-  if (Isv2x16VT(EltVT)) {
-    assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
+  if (isVectorElementTypeUpsized(EltVT)) {
     EltVT = MVT::i32;
     ToType = NVPTX::PTXLdStInstCode::Untyped;
     ToTypeWidth = 32;