Skip to content

Commit a2be454

Browse files
authored
[NVPTX] Use v2.u64 to load/store 128-bit values (#136638)
1 parent e112dcc commit a2be454

File tree

8 files changed

+114
-139
lines changed

8 files changed

+114
-139
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11731173
return true;
11741174
}
11751175

1176-
static bool isVectorElementTypeUpsized(EVT EltVT) {
1176+
static bool isSubVectorPackedInI32(EVT EltVT) {
11771177
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
11781178
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
11791179
// vectorized loads/stores with the actual element type for i8/i16 as that
@@ -1186,60 +1186,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {
11861186

11871187
bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11881188
MemSDNode *MemSD = cast<MemSDNode>(N);
1189-
EVT LoadedVT = MemSD->getMemoryVT();
1190-
if (!LoadedVT.isSimple())
1189+
const EVT MemEVT = MemSD->getMemoryVT();
1190+
if (!MemEVT.isSimple())
11911191
return false;
1192+
const MVT MemVT = MemEVT.getSimpleVT();
11921193

11931194
// Address Space Setting
11941195
unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
11951196
if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
11961197
return tryLDGLDU(N);
11971198
}
11981199

1200+
EVT EltVT = N->getValueType(0);
11991201
SDLoc DL(N);
12001202
SDValue Chain = N->getOperand(0);
12011203
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
12021204

1203-
// Vector Setting
1204-
MVT SimpleVT = LoadedVT.getSimpleVT();
1205-
12061205
// Type Setting: fromType + fromTypeWidth
12071206
//
12081207
// Sign : ISD::SEXTLOAD
12091208
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
12101209
// type is integer
12111210
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1212-
MVT ScalarVT = SimpleVT.getScalarType();
12131211
// Read at least 8 bits (predicates are stored as 8-bit values)
1214-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1215-
unsigned int FromType;
12161212
// The last operand holds the original LoadSDNode::getExtensionType() value
1217-
unsigned ExtensionType = cast<ConstantSDNode>(
1218-
N->getOperand(N->getNumOperands() - 1))->getZExtValue();
1219-
if (ExtensionType == ISD::SEXTLOAD)
1220-
FromType = NVPTX::PTXLdStInstCode::Signed;
1221-
else
1222-
FromType = getLdStRegType(ScalarVT);
1213+
const unsigned TotalWidth = MemVT.getSizeInBits();
1214+
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
1215+
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
1216+
? NVPTX::PTXLdStInstCode::Signed
1217+
: getLdStRegType(MemVT.getScalarType());
12231218

12241219
unsigned VecType;
1225-
1220+
unsigned FromTypeWidth;
12261221
switch (N->getOpcode()) {
12271222
case NVPTXISD::LoadV2:
1223+
FromTypeWidth = TotalWidth / 2;
12281224
VecType = NVPTX::PTXLdStInstCode::V2;
12291225
break;
12301226
case NVPTXISD::LoadV4:
1227+
FromTypeWidth = TotalWidth / 4;
12311228
VecType = NVPTX::PTXLdStInstCode::V4;
12321229
break;
12331230
default:
12341231
return false;
12351232
}
12361233

1237-
EVT EltVT = N->getValueType(0);
1238-
1239-
if (isVectorElementTypeUpsized(EltVT)) {
1234+
if (isSubVectorPackedInI32(EltVT)) {
12401235
EltVT = MVT::i32;
12411236
FromType = NVPTX::PTXLdStInstCode::Untyped;
1242-
FromTypeWidth = 32;
12431237
}
12441238

12451239
SDValue Offset, Base;
@@ -1289,9 +1283,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12891283
// LDG/LDU SD node (from custom vector handling), then it's the second operand
12901284
SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);
12911285

1292-
EVT OrigType = N->getValueType(0);
1286+
const EVT OrigType = N->getValueType(0);
12931287
EVT EltVT = Mem->getMemoryVT();
12941288
unsigned NumElts = 1;
1289+
1290+
if (EltVT == MVT::i128 || EltVT == MVT::f128) {
1291+
EltVT = MVT::i64;
1292+
NumElts = 2;
1293+
}
12951294
if (EltVT.isVector()) {
12961295
NumElts = EltVT.getVectorNumElements();
12971296
EltVT = EltVT.getVectorElementType();
@@ -1311,11 +1310,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13111310
// Build the "promoted" result VTList for the load. If we are really loading
13121311
// i8s, then the return type will be promoted to i16 since we do not expose
13131312
// 8-bit registers in NVPTX.
1314-
EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
1313+
const EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
13151314
SmallVector<EVT, 5> InstVTs;
1316-
for (unsigned i = 0; i != NumElts; ++i) {
1317-
InstVTs.push_back(NodeVT);
1318-
}
1315+
InstVTs.append(NumElts, NodeVT);
13191316
InstVTs.push_back(MVT::Other);
13201317
SDVTList InstVTList = CurDAG->getVTList(InstVTs);
13211318
SDValue Chain = N->getOperand(0);
@@ -1494,6 +1491,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14941491
EVT EltVT = Op1.getValueType();
14951492
MemSDNode *MemSD = cast<MemSDNode>(N);
14961493
EVT StoreVT = MemSD->getMemoryVT();
1494+
assert(StoreVT.isSimple() && "Store value is not simple");
14971495

14981496
// Address Space Setting
14991497
unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
@@ -1508,35 +1506,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15081506

15091507
// Type Setting: toType + toTypeWidth
15101508
// - for integer type, always use 'u'
1511-
assert(StoreVT.isSimple() && "Store value is not simple");
1512-
MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
1513-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1514-
unsigned ToType = getLdStRegType(ScalarVT);
1509+
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
1510+
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());
15151511

15161512
SmallVector<SDValue, 12> Ops;
15171513
SDValue N2;
15181514
unsigned VecType;
1515+
unsigned ToTypeWidth;
15191516

15201517
switch (N->getOpcode()) {
15211518
case NVPTXISD::StoreV2:
15221519
VecType = NVPTX::PTXLdStInstCode::V2;
15231520
Ops.append({N->getOperand(1), N->getOperand(2)});
15241521
N2 = N->getOperand(3);
1522+
ToTypeWidth = TotalWidth / 2;
15251523
break;
15261524
case NVPTXISD::StoreV4:
15271525
VecType = NVPTX::PTXLdStInstCode::V4;
15281526
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
15291527
N->getOperand(4)});
15301528
N2 = N->getOperand(5);
1529+
ToTypeWidth = TotalWidth / 4;
15311530
break;
15321531
default:
15331532
return false;
15341533
}
15351534

1536-
if (isVectorElementTypeUpsized(EltVT)) {
1535+
if (isSubVectorPackedInI32(EltVT)) {
15371536
EltVT = MVT::i32;
15381537
ToType = NVPTX::PTXLdStInstCode::Untyped;
1539-
ToTypeWidth = 32;
15401538
}
15411539

15421540
SDValue Offset, Base;

0 commit comments

Comments
 (0)