Skip to content

Commit eb8e0b9

Browse files
committed
[NVPTX] Remove load/store type
1 parent ae6b4b2 commit eb8e0b9

File tree

189 files changed

+8554
-8582
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

189 files changed

+8554
-8582
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 13 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,21 +1077,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
10771077
}
10781078
}
10791079

1080-
static int getLdStRegType(EVT VT) {
1081-
if (VT.isFloatingPoint())
1082-
switch (VT.getSimpleVT().SimpleTy) {
1083-
case MVT::f16:
1084-
case MVT::bf16:
1085-
case MVT::v2f16:
1086-
case MVT::v2bf16:
1087-
return NVPTX::PTXLdStInstCode::Untyped;
1088-
default:
1089-
return NVPTX::PTXLdStInstCode::Float;
1090-
}
1091-
else
1092-
return NVPTX::PTXLdStInstCode::Unsigned;
1093-
}
1094-
10951080
bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10961081
MemSDNode *LD = cast<MemSDNode>(N);
10971082
assert(LD->readMem() && "Expected load");
@@ -1122,32 +1107,22 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11221107
// type is integer
11231108
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
11241109
MVT SimpleVT = LoadedVT.getSimpleVT();
1125-
MVT ScalarVT = SimpleVT.getScalarType();
11261110
// Read at least 8 bits (predicates are stored as 8-bit values)
1127-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1128-
unsigned int FromType;
1111+
unsigned FromTypeWidth = std::max(8U, (unsigned)SimpleVT.getSizeInBits());
11291112

11301113
// Vector Setting
1131-
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1132-
if (SimpleVT.isVector()) {
1133-
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
1134-
"Unexpected vector type");
1135-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1136-
FromTypeWidth = 32;
1137-
}
1138-
1139-
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1140-
FromType = NVPTX::PTXLdStInstCode::Signed;
1141-
else
1142-
FromType = getLdStRegType(ScalarVT);
1114+
unsigned int FromType =
1115+
(PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1116+
? NVPTX::PTXLdStInstCode::Signed
1117+
: NVPTX::PTXLdStInstCode::Untyped;
11431118

11441119
// Create the machine instruction DAG
11451120
SDValue Offset, Base;
11461121
SelectADDR(N->getOperand(1), Base, Offset);
11471122
SDValue Ops[] = {getI32Imm(Ordering, DL),
11481123
getI32Imm(Scope, DL),
11491124
getI32Imm(CodeAddrSpace, DL),
1150-
getI32Imm(VecType, DL),
1125+
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
11511126
getI32Imm(FromType, DL),
11521127
getI32Imm(FromTypeWidth, DL),
11531128
Base,
@@ -1214,7 +1189,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12141189
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
12151190
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
12161191
? NVPTX::PTXLdStInstCode::Signed
1217-
: getLdStRegType(MemVT.getScalarType());
1192+
: NVPTX::PTXLdStInstCode::Untyped;
12181193

12191194
unsigned VecType;
12201195
unsigned FromTypeWidth;
@@ -1232,8 +1207,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12321207
}
12331208

12341209
if (isSubVectorPackedInI32(EltVT)) {
1210+
assert(ExtensionType == ISD::NON_EXTLOAD);
12351211
EltVT = MVT::i32;
1236-
FromType = NVPTX::PTXLdStInstCode::Untyped;
12371212
}
12381213

12391214
SDValue Offset, Base;
@@ -1434,21 +1409,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14341409
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
14351410

14361411
// Vector Setting
1437-
MVT SimpleVT = StoreVT.getSimpleVT();
1438-
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1439-
1440-
// Type Setting: toType + toTypeWidth
1441-
// - for integer type, always use 'u'
1442-
MVT ScalarVT = SimpleVT.getScalarType();
1443-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1444-
if (SimpleVT.isVector()) {
1445-
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1446-
"Unexpected vector type");
1447-
// v2x16 is stored using st.b32
1448-
ToTypeWidth = 32;
1449-
}
1450-
1451-
unsigned int ToType = getLdStRegType(ScalarVT);
1412+
const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();
14521413

14531414
// Create the machine instruction DAG
14541415
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
@@ -1460,8 +1421,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14601421
getI32Imm(Ordering, DL),
14611422
getI32Imm(Scope, DL),
14621423
getI32Imm(CodeAddrSpace, DL),
1463-
getI32Imm(VecType, DL),
1464-
getI32Imm(ToType, DL),
1424+
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
1425+
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
14651426
getI32Imm(ToTypeWidth, DL),
14661427
Base,
14671428
Offset,
@@ -1507,7 +1468,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15071468
// Type Setting: toType + toTypeWidth
15081469
// - for integer type, always use 'u'
15091470
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
1510-
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());
15111471

15121472
SmallVector<SDValue, 12> Ops;
15131473
SDValue N2;
@@ -1534,16 +1494,15 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15341494

15351495
if (isSubVectorPackedInI32(EltVT)) {
15361496
EltVT = MVT::i32;
1537-
ToType = NVPTX::PTXLdStInstCode::Untyped;
15381497
}
15391498

15401499
SDValue Offset, Base;
15411500
SelectADDR(N2, Base, Offset);
15421501

15431502
Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
15441503
getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
1545-
getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL), Base, Offset,
1546-
Chain});
1504+
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
1505+
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});
15471506

15481507
std::optional<unsigned> Opcode;
15491508
switch (N->getOpcode()) {

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,11 +2296,11 @@ def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">;
22962296
def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">;
22972297
def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">;
22982298
def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">;
2299-
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">;
2300-
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">;
2301-
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
2302-
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
2303-
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
2299+
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".b32">;
2300+
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".b64">;
2301+
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".b32">;
2302+
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".b64">;
2303+
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".b32">;
23042304

23052305
defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
23062306
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
@@ -2319,13 +2319,13 @@ defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
23192319
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
23202320
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;
23212321

2322-
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
2323-
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;
2322+
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".b32">;
2323+
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".b64">;
23242324

2325-
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
2326-
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
2325+
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".b32">;
2326+
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".b64">;
23272327

2328-
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
2328+
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".b32">;
23292329

23302330
def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
23312331
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;
@@ -2341,11 +2341,11 @@ def StoreRetvalV4I32 : StoreRetvalV4Inst<Int32Regs, ".b32">;
23412341
def StoreRetvalV4I16 : StoreRetvalV4Inst<Int16Regs, ".b16">;
23422342
def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">;
23432343

2344-
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">;
2345-
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">;
2346-
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">;
2347-
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">;
2348-
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">;
2344+
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".b64">;
2345+
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".b32">;
2346+
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".b64">;
2347+
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".b32">;
2348+
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".b32">;
23492349

23502350
def CallArgBeginInst : NVPTXInst<(outs), (ins), "(", [(CallArgBegin)]>;
23512351
def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,12 +2302,12 @@ class LDU_G<string TyStr, NVPTXRegClass regclass>
23022302
"ldu.global." # TyStr # " \t$result, [$src];",
23032303
[]>, Requires<[hasLDU]>;
23042304

2305-
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8", Int16Regs>;
2306-
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16", Int16Regs>;
2307-
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32", Int32Regs>;
2308-
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64", Int64Regs>;
2309-
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32", Float32Regs>;
2310-
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64", Float64Regs>;
2305+
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"b8", Int16Regs>;
2306+
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"b16", Int16Regs>;
2307+
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"b32", Int32Regs>;
2308+
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"b64", Int64Regs>;
2309+
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"b32", Float32Regs>;
2310+
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"b64", Float64Regs>;
23112311

23122312
// vector
23132313

@@ -2324,19 +2324,19 @@ class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass>
23242324
"ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
23252325

23262326

2327-
def INT_PTX_LDU_G_v2i8_ELE : VLDU_G_ELE_V2<"u8", Int16Regs>;
2328-
def INT_PTX_LDU_G_v2i16_ELE : VLDU_G_ELE_V2<"u16", Int16Regs>;
2329-
def INT_PTX_LDU_G_v2i32_ELE : VLDU_G_ELE_V2<"u32", Int32Regs>;
2330-
def INT_PTX_LDU_G_v2f32_ELE : VLDU_G_ELE_V2<"f32", Float32Regs>;
2331-
def INT_PTX_LDU_G_v2i64_ELE : VLDU_G_ELE_V2<"u64", Int64Regs>;
2332-
def INT_PTX_LDU_G_v2f64_ELE : VLDU_G_ELE_V2<"f64", Float64Regs>;
2327+
def INT_PTX_LDU_G_v2i8_ELE : VLDU_G_ELE_V2<"b8", Int16Regs>;
2328+
def INT_PTX_LDU_G_v2i16_ELE : VLDU_G_ELE_V2<"b16", Int16Regs>;
2329+
def INT_PTX_LDU_G_v2i32_ELE : VLDU_G_ELE_V2<"b32", Int32Regs>;
2330+
def INT_PTX_LDU_G_v2f32_ELE : VLDU_G_ELE_V2<"b32", Float32Regs>;
2331+
def INT_PTX_LDU_G_v2i64_ELE : VLDU_G_ELE_V2<"b64", Int64Regs>;
2332+
def INT_PTX_LDU_G_v2f64_ELE : VLDU_G_ELE_V2<"b64", Float64Regs>;
23332333

2334-
def INT_PTX_LDU_G_v4i8_ELE : VLDU_G_ELE_V4<"u8", Int16Regs>;
2335-
def INT_PTX_LDU_G_v4i16_ELE : VLDU_G_ELE_V4<"u16", Int16Regs>;
2336-
def INT_PTX_LDU_G_v4i32_ELE : VLDU_G_ELE_V4<"u32", Int32Regs>;
2334+
def INT_PTX_LDU_G_v4i8_ELE : VLDU_G_ELE_V4<"b8", Int16Regs>;
2335+
def INT_PTX_LDU_G_v4i16_ELE : VLDU_G_ELE_V4<"b16", Int16Regs>;
2336+
def INT_PTX_LDU_G_v4i32_ELE : VLDU_G_ELE_V4<"b32", Int32Regs>;
23372337
def INT_PTX_LDU_G_v4f16_ELE : VLDU_G_ELE_V4<"b16", Int16Regs>;
23382338
def INT_PTX_LDU_G_v4f16x2_ELE : VLDU_G_ELE_V4<"b32", Int32Regs>;
2339-
def INT_PTX_LDU_G_v4f32_ELE : VLDU_G_ELE_V4<"f32", Float32Regs>;
2339+
def INT_PTX_LDU_G_v4f32_ELE : VLDU_G_ELE_V4<"b32", Float32Regs>;
23402340

23412341

23422342
//-----------------------------------
@@ -2352,12 +2352,12 @@ class LDG_G<string TyStr, NVPTXRegClass regclass>
23522352
"ld.global.nc." # TyStr # " \t$result, [$src];",
23532353
[]>, Requires<[hasLDG]>;
23542354

2355-
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"u8", Int16Regs>;
2356-
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"u16", Int16Regs>;
2357-
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"u32", Int32Regs>;
2358-
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"u64", Int64Regs>;
2359-
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32", Float32Regs>;
2360-
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"f64", Float64Regs>;
2355+
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"b8", Int16Regs>;
2356+
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"b16", Int16Regs>;
2357+
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"b32", Int32Regs>;
2358+
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"b64", Int64Regs>;
2359+
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"b32", Float32Regs>;
2360+
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"b64", Float64Regs>;
23612361

23622362
// vector
23632363

@@ -2374,17 +2374,17 @@ class VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> :
23742374
"ld.global.nc.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
23752375

23762376
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
2377-
def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"u8", Int16Regs>;
2378-
def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"u16", Int16Regs>;
2379-
def INT_PTX_LDG_G_v2i32_ELE : VLDG_G_ELE_V2<"u32", Int32Regs>;
2380-
def INT_PTX_LDG_G_v2f32_ELE : VLDG_G_ELE_V2<"f32", Float32Regs>;
2381-
def INT_PTX_LDG_G_v2i64_ELE : VLDG_G_ELE_V2<"u64", Int64Regs>;
2382-
def INT_PTX_LDG_G_v2f64_ELE : VLDG_G_ELE_V2<"f64", Float64Regs>;
2383-
2384-
def INT_PTX_LDG_G_v4i8_ELE : VLDG_G_ELE_V4<"u8", Int16Regs>;
2385-
def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"u16", Int16Regs>;
2386-
def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"u32", Int32Regs>;
2387-
def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"f32", Float32Regs>;
2377+
def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"b8", Int16Regs>;
2378+
def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"b16", Int16Regs>;
2379+
def INT_PTX_LDG_G_v2i32_ELE : VLDG_G_ELE_V2<"b32", Int32Regs>;
2380+
def INT_PTX_LDG_G_v2f32_ELE : VLDG_G_ELE_V2<"b32", Float32Regs>;
2381+
def INT_PTX_LDG_G_v2i64_ELE : VLDG_G_ELE_V2<"b64", Int64Regs>;
2382+
def INT_PTX_LDG_G_v2f64_ELE : VLDG_G_ELE_V2<"b64", Float64Regs>;
2383+
2384+
def INT_PTX_LDG_G_v4i8_ELE : VLDG_G_ELE_V4<"b8", Int16Regs>;
2385+
def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"b16", Int16Regs>;
2386+
def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"b32", Int32Regs>;
2387+
def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"b32", Float32Regs>;
23882388

23892389

23902390
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {

0 commit comments

Comments
 (0)