Skip to content

[NVPTX] Use v2.u64 to load/store 128-bit values #136638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 29 additions & 31 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
return true;
}

static bool isVectorElementTypeUpsized(EVT EltVT) {
static bool isSubVectorPackedInI32(EVT EltVT) {
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
// vectorized loads/stores with the actual element type for i8/i16 as that
Expand All @@ -1168,60 +1168,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {

bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
MemSDNode *MemSD = cast<MemSDNode>(N);
EVT LoadedVT = MemSD->getMemoryVT();
if (!LoadedVT.isSimple())
const EVT MemEVT = MemSD->getMemoryVT();
if (!MemEVT.isSimple())
return false;
const MVT MemVT = MemEVT.getSimpleVT();

// Address Space Setting
unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
return tryLDGLDU(N);
}

EVT EltVT = N->getValueType(0);
SDLoc DL(N);
SDValue Chain = N->getOperand(0);
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);

// Vector Setting
MVT SimpleVT = LoadedVT.getSimpleVT();

// Type Setting: fromType + fromTypeWidth
//
// Sign : ISD::SEXTLOAD
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
// type is integer
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
MVT ScalarVT = SimpleVT.getScalarType();
// Read at least 8 bits (predicates are stored as 8-bit values)
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
unsigned int FromType;
// The last operand holds the original LoadSDNode::getExtensionType() value
unsigned ExtensionType = cast<ConstantSDNode>(
N->getOperand(N->getNumOperands() - 1))->getZExtValue();
if (ExtensionType == ISD::SEXTLOAD)
FromType = NVPTX::PTXLdStInstCode::Signed;
else
FromType = getLdStRegType(ScalarVT);
const unsigned TotalWidth = MemVT.getSizeInBits();
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: getLdStRegType(MemVT.getScalarType());

unsigned VecType;

unsigned FromTypeWidth;
switch (N->getOpcode()) {
case NVPTXISD::LoadV2:
FromTypeWidth = TotalWidth / 2;
VecType = NVPTX::PTXLdStInstCode::V2;
break;
case NVPTXISD::LoadV4:
FromTypeWidth = TotalWidth / 4;
VecType = NVPTX::PTXLdStInstCode::V4;
break;
default:
return false;
}

EVT EltVT = N->getValueType(0);

if (isVectorElementTypeUpsized(EltVT)) {
if (isSubVectorPackedInI32(EltVT)) {
EltVT = MVT::i32;
FromType = NVPTX::PTXLdStInstCode::Untyped;
FromTypeWidth = 32;
}

SDValue Offset, Base;
Expand Down Expand Up @@ -1271,9 +1265,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// LDG/LDU SD node (from custom vector handling), then its the second operand
SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);

EVT OrigType = N->getValueType(0);
const EVT OrigType = N->getValueType(0);
EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;

if (EltVT == MVT::i128 || EltVT == MVT::f128) {
EltVT = MVT::i64;
NumElts = 2;
}
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
Expand All @@ -1293,11 +1292,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// Build the "promoted" result VTList for the load. If we are really loading
// i8s, then the return type will be promoted to i16 since we do not expose
// 8-bit registers in NVPTX.
EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
const EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
SmallVector<EVT, 5> InstVTs;
for (unsigned i = 0; i != NumElts; ++i) {
InstVTs.push_back(NodeVT);
}
InstVTs.append(NumElts, NodeVT);
InstVTs.push_back(MVT::Other);
SDVTList InstVTList = CurDAG->getVTList(InstVTs);
SDValue Chain = N->getOperand(0);
Expand Down Expand Up @@ -1476,6 +1473,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
EVT EltVT = Op1.getValueType();
MemSDNode *MemSD = cast<MemSDNode>(N);
EVT StoreVT = MemSD->getMemoryVT();
assert(StoreVT.isSimple() && "Store value is not simple");

// Address Space Setting
unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
Expand All @@ -1490,35 +1488,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {

// Type Setting: toType + toTypeWidth
// - for integer type, always use 'u'
assert(StoreVT.isSimple() && "Store value is not simple");
MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
unsigned ToType = getLdStRegType(ScalarVT);
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());

SmallVector<SDValue, 12> Ops;
SDValue N2;
unsigned VecType;
unsigned ToTypeWidth;

switch (N->getOpcode()) {
case NVPTXISD::StoreV2:
VecType = NVPTX::PTXLdStInstCode::V2;
Ops.append({N->getOperand(1), N->getOperand(2)});
N2 = N->getOperand(3);
ToTypeWidth = TotalWidth / 2;
break;
case NVPTXISD::StoreV4:
VecType = NVPTX::PTXLdStInstCode::V4;
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
N->getOperand(4)});
N2 = N->getOperand(5);
ToTypeWidth = TotalWidth / 4;
break;
default:
return false;
}

if (isVectorElementTypeUpsized(EltVT)) {
if (isSubVectorPackedInI32(EltVT)) {
EltVT = MVT::i32;
ToType = NVPTX::PTXLdStInstCode::Untyped;
ToTypeWidth = 32;
}

SDValue Offset, Base;
Expand Down
Loading
Loading