Skip to content

Commit 3b81597

Browse files
committed
[NVPTX] Use v2.u64 to load/store 128-bit values
1 parent 209d8c8 commit 3b81597

File tree

8 files changed

+110
-134
lines changed

8 files changed

+110
-134
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 27 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1168,60 +1168,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {
11681168

11691169
bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11701170
MemSDNode *MemSD = cast<MemSDNode>(N);
1171-
EVT LoadedVT = MemSD->getMemoryVT();
1172-
if (!LoadedVT.isSimple())
1171+
EVT MemEVT = MemSD->getMemoryVT();
1172+
if (!MemEVT.isSimple())
11731173
return false;
1174+
MVT MemVT = MemEVT.getSimpleVT();
11741175

11751176
// Address Space Setting
11761177
unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
11771178
if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
11781179
return tryLDGLDU(N);
11791180
}
11801181

1182+
EVT EltVT = N->getValueType(0);
11811183
SDLoc DL(N);
11821184
SDValue Chain = N->getOperand(0);
11831185
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
11841186

1185-
// Vector Setting
1186-
MVT SimpleVT = LoadedVT.getSimpleVT();
1187-
11881187
// Type Setting: fromType + fromTypeWidth
11891188
//
11901189
// Sign : ISD::SEXTLOAD
11911190
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
11921191
// type is integer
11931192
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1194-
MVT ScalarVT = SimpleVT.getScalarType();
11951193
// Read at least 8 bits (predicates are stored as 8-bit values)
1196-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1197-
unsigned int FromType;
11981194
// The last operand holds the original LoadSDNode::getExtensionType() value
1199-
unsigned ExtensionType = cast<ConstantSDNode>(
1200-
N->getOperand(N->getNumOperands() - 1))->getZExtValue();
1201-
if (ExtensionType == ISD::SEXTLOAD)
1202-
FromType = NVPTX::PTXLdStInstCode::Signed;
1203-
else
1204-
FromType = getLdStRegType(ScalarVT);
1195+
const unsigned TotalWidth = MemVT.getSizeInBits();
1196+
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
1197+
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
1198+
? NVPTX::PTXLdStInstCode::Signed
1199+
: getLdStRegType(MemVT.getScalarType());
12051200

12061201
unsigned VecType;
1207-
1202+
unsigned FromTypeWidth;
12081203
switch (N->getOpcode()) {
12091204
case NVPTXISD::LoadV2:
1205+
FromTypeWidth = TotalWidth / 2;
12101206
VecType = NVPTX::PTXLdStInstCode::V2;
12111207
break;
12121208
case NVPTXISD::LoadV4:
1209+
FromTypeWidth = TotalWidth / 4;
12131210
VecType = NVPTX::PTXLdStInstCode::V4;
12141211
break;
12151212
default:
12161213
return false;
12171214
}
12181215

1219-
EVT EltVT = N->getValueType(0);
1220-
12211216
if (isVectorElementTypeUpsized(EltVT)) {
12221217
EltVT = MVT::i32;
12231218
FromType = NVPTX::PTXLdStInstCode::Untyped;
1224-
FromTypeWidth = 32;
12251219
}
12261220

12271221
SDValue Offset, Base;
@@ -1271,9 +1265,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12711265
// LDG/LDU SD node (from custom vector handling), then it's the second operand
12721266
SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);
12731267

1274-
EVT OrigType = N->getValueType(0);
1268+
const EVT OrigType = N->getValueType(0);
12751269
EVT EltVT = Mem->getMemoryVT();
12761270
unsigned NumElts = 1;
1271+
1272+
if (EltVT == MVT::i128 || EltVT == MVT::f128) {
1273+
EltVT = MVT::i64;
1274+
NumElts = 2;
1275+
}
12771276
if (EltVT.isVector()) {
12781277
NumElts = EltVT.getVectorNumElements();
12791278
EltVT = EltVT.getVectorElementType();
@@ -1293,11 +1292,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12931292
// Build the "promoted" result VTList for the load. If we are really loading
12941293
// i8s, then the return type will be promoted to i16 since we do not expose
12951294
// 8-bit registers in NVPTX.
1296-
EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
1295+
const EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
12971296
SmallVector<EVT, 5> InstVTs;
1298-
for (unsigned i = 0; i != NumElts; ++i) {
1299-
InstVTs.push_back(NodeVT);
1300-
}
1297+
InstVTs.append(NumElts, NodeVT);
13011298
InstVTs.push_back(MVT::Other);
13021299
SDVTList InstVTList = CurDAG->getVTList(InstVTs);
13031300
SDValue Chain = N->getOperand(0);
@@ -1476,6 +1473,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14761473
EVT EltVT = Op1.getValueType();
14771474
MemSDNode *MemSD = cast<MemSDNode>(N);
14781475
EVT StoreVT = MemSD->getMemoryVT();
1476+
assert(StoreVT.isSimple() && "Store value is not simple");
14791477

14801478
// Address Space Setting
14811479
unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
@@ -1490,35 +1488,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14901488

14911489
// Type Setting: toType + toTypeWidth
14921490
// - for integer type, always use 'u'
1493-
assert(StoreVT.isSimple() && "Store value is not simple");
1494-
MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
1495-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1496-
unsigned ToType = getLdStRegType(ScalarVT);
1491+
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
1492+
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());
14971493

14981494
SmallVector<SDValue, 12> Ops;
14991495
SDValue N2;
15001496
unsigned VecType;
1497+
unsigned ToTypeWidth;
15011498

15021499
switch (N->getOpcode()) {
15031500
case NVPTXISD::StoreV2:
15041501
VecType = NVPTX::PTXLdStInstCode::V2;
15051502
Ops.append({N->getOperand(1), N->getOperand(2)});
15061503
N2 = N->getOperand(3);
1504+
ToTypeWidth = TotalWidth / 2;
15071505
break;
15081506
case NVPTXISD::StoreV4:
15091507
VecType = NVPTX::PTXLdStInstCode::V4;
15101508
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
15111509
N->getOperand(4)});
15121510
N2 = N->getOperand(5);
1511+
ToTypeWidth = TotalWidth / 4;
15131512
break;
15141513
default:
15151514
return false;
15161515
}
15171516

15181517
if (isVectorElementTypeUpsized(EltVT)) {
1519-
EltVT = MVT::i32;
15201518
ToType = NVPTX::PTXLdStInstCode::Untyped;
1521-
ToTypeWidth = 32;
1519+
EltVT = MVT::i32;
15221520
}
15231521

15241522
SDValue Offset, Base;

0 commit comments

Comments
 (0)