@@ -1168,60 +1168,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {
1168
1168
1169
1169
bool NVPTXDAGToDAGISel::tryLoadVector (SDNode *N) {
1170
1170
MemSDNode *MemSD = cast<MemSDNode>(N);
1171
- EVT LoadedVT = MemSD->getMemoryVT ();
1172
- if (!LoadedVT .isSimple ())
1171
+ EVT MemEVT = MemSD->getMemoryVT ();
1172
+ if (!MemEVT .isSimple ())
1173
1173
return false ;
1174
+ MVT MemVT = MemEVT.getSimpleVT ();
1174
1175
1175
1176
// Address Space Setting
1176
1177
unsigned int CodeAddrSpace = getCodeAddrSpace (MemSD);
1177
1178
if (canLowerToLDG (MemSD, *Subtarget, CodeAddrSpace, MF)) {
1178
1179
return tryLDGLDU (N);
1179
1180
}
1180
1181
1182
+ EVT EltVT = N->getValueType (0 );
1181
1183
SDLoc DL (N);
1182
1184
SDValue Chain = N->getOperand (0 );
1183
1185
auto [Ordering, Scope] = insertMemoryInstructionFence (DL, Chain, MemSD);
1184
1186
1185
- // Vector Setting
1186
- MVT SimpleVT = LoadedVT.getSimpleVT ();
1187
-
1188
1187
// Type Setting: fromType + fromTypeWidth
1189
1188
//
1190
1189
// Sign : ISD::SEXTLOAD
1191
1190
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
1192
1191
// type is integer
1193
1192
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1194
- MVT ScalarVT = SimpleVT.getScalarType ();
1195
1193
// Read at least 8 bits (predicates are stored as 8-bit values)
1196
- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1197
- unsigned int FromType;
1198
1194
// The last operand holds the original LoadSDNode::getExtensionType() value
1199
- unsigned ExtensionType = cast<ConstantSDNode>(
1200
- N->getOperand (N->getNumOperands () - 1 ))->getZExtValue ();
1201
- if (ExtensionType == ISD::SEXTLOAD)
1202
- FromType = NVPTX::PTXLdStInstCode::Signed;
1203
- else
1204
- FromType = getLdStRegType (ScalarVT);
1195
+ const unsigned TotalWidth = MemVT.getSizeInBits ();
1196
+ unsigned ExtensionType = N->getConstantOperandVal (N->getNumOperands () - 1 );
1197
+ unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
1198
+ ? NVPTX::PTXLdStInstCode::Signed
1199
+ : getLdStRegType (MemVT.getScalarType ());
1205
1200
1206
1201
unsigned VecType;
1207
-
1202
+ unsigned FromTypeWidth;
1208
1203
switch (N->getOpcode ()) {
1209
1204
case NVPTXISD::LoadV2:
1205
+ FromTypeWidth = TotalWidth / 2 ;
1210
1206
VecType = NVPTX::PTXLdStInstCode::V2;
1211
1207
break ;
1212
1208
case NVPTXISD::LoadV4:
1209
+ FromTypeWidth = TotalWidth / 4 ;
1213
1210
VecType = NVPTX::PTXLdStInstCode::V4;
1214
1211
break ;
1215
1212
default :
1216
1213
return false ;
1217
1214
}
1218
1215
1219
- EVT EltVT = N->getValueType (0 );
1220
-
1221
1216
if (isVectorElementTypeUpsized (EltVT)) {
1222
1217
EltVT = MVT::i32 ;
1223
1218
FromType = NVPTX::PTXLdStInstCode::Untyped;
1224
- FromTypeWidth = 32 ;
1225
1219
}
1226
1220
1227
1221
SDValue Offset, Base;
@@ -1271,9 +1265,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1271
1265
// LDG/LDU SD node (from custom vector handling), then it's the second operand
1272
1266
SDValue Op1 = N->getOperand (N->getOpcode () == ISD::INTRINSIC_W_CHAIN ? 2 : 1 );
1273
1267
1274
- EVT OrigType = N->getValueType (0 );
1268
+ const EVT OrigType = N->getValueType (0 );
1275
1269
EVT EltVT = Mem->getMemoryVT ();
1276
1270
unsigned NumElts = 1 ;
1271
+
1272
+ if (EltVT == MVT::i128 || EltVT == MVT::f128 ) {
1273
+ EltVT = MVT::i64 ;
1274
+ NumElts = 2 ;
1275
+ }
1277
1276
if (EltVT.isVector ()) {
1278
1277
NumElts = EltVT.getVectorNumElements ();
1279
1278
EltVT = EltVT.getVectorElementType ();
@@ -1293,11 +1292,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1293
1292
// Build the "promoted" result VTList for the load. If we are really loading
1294
1293
// i8s, then the return type will be promoted to i16 since we do not expose
1295
1294
// 8-bit registers in NVPTX.
1296
- EVT NodeVT = (EltVT == MVT::i8 ) ? MVT::i16 : EltVT;
1295
+ const EVT NodeVT = (EltVT == MVT::i8 ) ? MVT::i16 : EltVT;
1297
1296
SmallVector<EVT, 5 > InstVTs;
1298
- for (unsigned i = 0 ; i != NumElts; ++i) {
1299
- InstVTs.push_back (NodeVT);
1300
- }
1297
+ InstVTs.append (NumElts, NodeVT);
1301
1298
InstVTs.push_back (MVT::Other);
1302
1299
SDVTList InstVTList = CurDAG->getVTList (InstVTs);
1303
1300
SDValue Chain = N->getOperand (0 );
@@ -1476,6 +1473,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
1476
1473
EVT EltVT = Op1.getValueType ();
1477
1474
MemSDNode *MemSD = cast<MemSDNode>(N);
1478
1475
EVT StoreVT = MemSD->getMemoryVT ();
1476
+ assert (StoreVT.isSimple () && " Store value is not simple" );
1479
1477
1480
1478
// Address Space Setting
1481
1479
unsigned CodeAddrSpace = getCodeAddrSpace (MemSD);
@@ -1490,35 +1488,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
1490
1488
1491
1489
// Type Setting: toType + toTypeWidth
1492
1490
// - for integer type, always use 'u'
1493
- assert (StoreVT.isSimple () && " Store value is not simple" );
1494
- MVT ScalarVT = StoreVT.getSimpleVT ().getScalarType ();
1495
- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1496
- unsigned ToType = getLdStRegType (ScalarVT);
1491
+ const unsigned TotalWidth = StoreVT.getSimpleVT ().getSizeInBits ();
1492
+ unsigned ToType = getLdStRegType (StoreVT.getSimpleVT ().getScalarType ());
1497
1493
1498
1494
SmallVector<SDValue, 12 > Ops;
1499
1495
SDValue N2;
1500
1496
unsigned VecType;
1497
+ unsigned ToTypeWidth;
1501
1498
1502
1499
switch (N->getOpcode ()) {
1503
1500
case NVPTXISD::StoreV2:
1504
1501
VecType = NVPTX::PTXLdStInstCode::V2;
1505
1502
Ops.append ({N->getOperand (1 ), N->getOperand (2 )});
1506
1503
N2 = N->getOperand (3 );
1504
+ ToTypeWidth = TotalWidth / 2 ;
1507
1505
break ;
1508
1506
case NVPTXISD::StoreV4:
1509
1507
VecType = NVPTX::PTXLdStInstCode::V4;
1510
1508
Ops.append ({N->getOperand (1 ), N->getOperand (2 ), N->getOperand (3 ),
1511
1509
N->getOperand (4 )});
1512
1510
N2 = N->getOperand (5 );
1511
+ ToTypeWidth = TotalWidth / 4 ;
1513
1512
break ;
1514
1513
default :
1515
1514
return false ;
1516
1515
}
1517
1516
1518
1517
if (isVectorElementTypeUpsized (EltVT)) {
1519
- EltVT = MVT::i32 ;
1520
1518
ToType = NVPTX::PTXLdStInstCode::Untyped;
1521
- ToTypeWidth = 32 ;
1519
+ EltVT = MVT:: i32 ;
1522
1520
}
1523
1521
1524
1522
SDValue Offset, Base;
0 commit comments