Skip to content

Commit 369891b

Browse files
authored
[NVPTX] use untyped loads and stores where ever possible (#137698)
In most cases, the type information attached to load and store instructions is meaningless and inconsistently applied. We can usually use ".b" loads and avoid the complexity of trying to assign the correct type. The one expectation is sign-extending load, which will continue to use ".s" to ensure the sign extension into a larger register is done correctly.
1 parent 1c8cc3b commit 369891b

File tree

203 files changed

+8669
-8697
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

203 files changed

+8669
-8697
lines changed

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
1212
//
1313
__device__ void test_arg(__bf16 *out, __bf16 in) {
14-
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
14+
// CHECK-DAG: ld.param.b64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
1515
// CHECK-DAG: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_argPDF16bDF16b_param_1];
1616
__bf16 bf16 = in;
1717
*out = bf16;

clang/test/CodeGenCUDA/fp-contract.cu

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -179,26 +179,26 @@
179179
__host__ __device__ float func(float a, float b, float c) { return a + b * c; }
180180
// COMMON-LABEL: _Z4funcfff
181181
// NV-ON: fma.rn.f32
182-
// NV-ON-NEXT: st.param.f32
182+
// NV-ON-NEXT: st.param.b32
183183
// AMD-ON: v_fmac_f32_e64
184184
// AMD-ON-NEXT: s_setpc_b64
185185

186186
// NV-OFF: mul.rn.f32
187187
// NV-OFF-NEXT: add.rn.f32
188-
// NV-OFF-NEXT: st.param.f32
188+
// NV-OFF-NEXT: st.param.b32
189189
// AMD-OFF: v_mul_f32_e64
190190
// AMD-OFF-NEXT: v_add_f32_e64
191191
// AMD-OFF-NEXT: s_setpc_b64
192192

193193
// NV-OPT-FAST: fma.rn.f32
194-
// NV-OPT-FAST-NEXT: st.param.f32
194+
// NV-OPT-FAST-NEXT: st.param.b32
195195
// NV-OPT-FASTSTD: fma.rn.f32
196-
// NV-OPT-FASTSTD-NEXT: st.param.f32
196+
// NV-OPT-FASTSTD-NEXT: st.param.b32
197197
// NV-OPT-ON: fma.rn.f32
198-
// NV-OPT-ON-NEXT: st.param.f32
198+
// NV-OPT-ON-NEXT: st.param.b32
199199
// NV-OPT-OFF: mul.rn.f32
200200
// NV-OPT-OFF-NEXT: add.rn.f32
201-
// NV-OPT-OFF-NEXT: st.param.f32
201+
// NV-OPT-OFF-NEXT: st.param.b32
202202

203203
// AMD-OPT-FAST-IR: fmul contract float
204204
// AMD-OPT-FAST-IR: fadd contract float
@@ -224,15 +224,15 @@ __host__ __device__ float func2(float a, float b, float c) {
224224
}
225225
// COMMON-LABEL: _Z5func2fff
226226
// NV-OPT-FAST: fma.rn.f32
227-
// NV-OPT-FAST-NEXT: st.param.f32
227+
// NV-OPT-FAST-NEXT: st.param.b32
228228
// NV-OPT-FASTSTD: fma.rn.f32
229-
// NV-OPT-FASTSTD-NEXT: st.param.f32
229+
// NV-OPT-FASTSTD-NEXT: st.param.b32
230230
// NV-OPT-ON: mul.rn.f32
231231
// NV-OPT-ON: add.rn.f32
232-
// NV-OPT-ON-NEXT: st.param.f32
232+
// NV-OPT-ON-NEXT: st.param.b32
233233
// NV-OPT-OFF: mul.rn.f32
234234
// NV-OPT-OFF: add.rn.f32
235-
// NV-OPT-OFF-NEXT: st.param.f32
235+
// NV-OPT-OFF-NEXT: st.param.b32
236236

237237
// AMD-OPT-FAST-IR: fmul contract float
238238
// AMD-OPT-FAST-IR: fadd contract float
@@ -267,16 +267,16 @@ __host__ __device__ float func2(float a, float b, float c) {
267267
}
268268
// COMMON-LABEL: _Z5func3fff
269269
// NV-OPT-FAST: fma.rn.f32
270-
// NV-OPT-FAST-NEXT: st.param.f32
270+
// NV-OPT-FAST-NEXT: st.param.b32
271271
// NV-OPT-FASTSTD: mul.rn.f32
272272
// NV-OPT-FASTSTD: add.rn.f32
273-
// NV-OPT-FASTSTD-NEXT: st.param.f32
273+
// NV-OPT-FASTSTD-NEXT: st.param.b32
274274
// NV-OPT-ON: mul.rn.f32
275275
// NV-OPT-ON: add.rn.f32
276-
// NV-OPT-ON-NEXT: st.param.f32
276+
// NV-OPT-ON-NEXT: st.param.b32
277277
// NV-OPT-OFF: mul.rn.f32
278278
// NV-OPT-OFF: add.rn.f32
279-
// NV-OPT-OFF-NEXT: st.param.f32
279+
// NV-OPT-OFF-NEXT: st.param.b32
280280

281281
// AMD-OPT-FAST-IR: fmul float
282282
// AMD-OPT-FAST-IR: fadd float

clang/test/CodeGenCUDA/memcpy-libcall.cu

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
// PTX-LABEL: .func _Z12copy_genericPvPKv(
1111
void __device__ copy_generic(void *dest, const void *src) {
1212
__builtin_memcpy(dest, src, 32);
13-
// PTX: ld.u8
14-
// PTX: st.u8
13+
// PTX: ld.b8
14+
// PTX: st.b8
1515
}
1616

1717
// PTX-LABEL: .entry _Z11copy_globalPvS_(
1818
void __global__ copy_global(void *dest, void * src) {
1919
__builtin_memcpy(dest, src, 32);
20-
// PTX: ld.global.u8
21-
// PTX: st.global.u8
20+
// PTX: ld.global.b8
21+
// PTX: st.global.b8
2222
}
2323

2424
struct S {
@@ -28,37 +28,37 @@ struct S {
2828
// PTX-LABEL: .entry _Z20copy_param_to_globalP1SS_(
2929
void __global__ copy_param_to_global(S *global, S param) {
3030
__builtin_memcpy(global, &param, sizeof(S));
31-
// PTX: ld.param.u32
32-
// PTX: st.global.u32
31+
// PTX: ld.param.b32
32+
// PTX: st.global.b32
3333
}
3434

3535
// PTX-LABEL: .entry _Z19copy_param_to_localPU3AS51SS_(
3636
void __global__ copy_param_to_local(__attribute__((address_space(5))) S *local,
3737
S param) {
3838
__builtin_memcpy(local, &param, sizeof(S));
39-
// PTX: ld.param.u32
40-
// PTX: st.local.u32
39+
// PTX: ld.param.b32
40+
// PTX: st.local.b32
4141
}
4242

4343
// PTX-LABEL: .func _Z21copy_local_to_genericP1SPU3AS5S_(
4444
void __device__ copy_local_to_generic(S *generic,
4545
__attribute__((address_space(5))) S *src) {
4646
__builtin_memcpy(generic, src, sizeof(S));
47-
// PTX: ld.local.u32
48-
// PTX: st.u32
47+
// PTX: ld.local.b32
48+
// PTX: st.b32
4949
}
5050

5151
__shared__ S shared;
5252

5353
// PTX-LABEL: .entry _Z20copy_param_to_shared1S(
5454
void __global__ copy_param_to_shared( S param) {
5555
__builtin_memcpy(&shared, &param, sizeof(S));
56-
// PTX: ld.param.u32
57-
// PTX: st.shared.u32
56+
// PTX: ld.param.b32
57+
// PTX: st.shared.b32
5858
}
5959

6060
void __device__ copy_shared_to_generic(S *generic) {
6161
__builtin_memcpy(generic, &shared, sizeof(S));
62-
// PTX: ld.shared.u32
63-
// PTX: st.u32
62+
// PTX: ld.shared.b32
63+
// PTX: st.b32
6464
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 13 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,21 +1044,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
10441044
}
10451045
}
10461046

1047-
static int getLdStRegType(EVT VT) {
1048-
if (VT.isFloatingPoint())
1049-
switch (VT.getSimpleVT().SimpleTy) {
1050-
case MVT::f16:
1051-
case MVT::bf16:
1052-
case MVT::v2f16:
1053-
case MVT::v2bf16:
1054-
return NVPTX::PTXLdStInstCode::Untyped;
1055-
default:
1056-
return NVPTX::PTXLdStInstCode::Float;
1057-
}
1058-
else
1059-
return NVPTX::PTXLdStInstCode::Unsigned;
1060-
}
1061-
10621047
bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10631048
MemSDNode *LD = cast<MemSDNode>(N);
10641049
assert(LD->readMem() && "Expected load");
@@ -1088,24 +1073,14 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10881073
// type is integer
10891074
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
10901075
MVT SimpleVT = LoadedVT.getSimpleVT();
1091-
MVT ScalarVT = SimpleVT.getScalarType();
10921076
// Read at least 8 bits (predicates are stored as 8-bit values)
1093-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1094-
unsigned int FromType;
1077+
unsigned FromTypeWidth = std::max(8U, (unsigned)SimpleVT.getSizeInBits());
10951078

10961079
// Vector Setting
1097-
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1098-
if (SimpleVT.isVector()) {
1099-
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
1100-
"Unexpected vector type");
1101-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1102-
FromTypeWidth = 32;
1103-
}
1104-
1105-
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1106-
FromType = NVPTX::PTXLdStInstCode::Signed;
1107-
else
1108-
FromType = getLdStRegType(ScalarVT);
1080+
unsigned int FromType =
1081+
(PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1082+
? NVPTX::PTXLdStInstCode::Signed
1083+
: NVPTX::PTXLdStInstCode::Untyped;
11091084

11101085
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
11111086
FromTypeWidth <= 128 && "Invalid width for load");
@@ -1116,7 +1091,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11161091
SDValue Ops[] = {getI32Imm(Ordering, DL),
11171092
getI32Imm(Scope, DL),
11181093
getI32Imm(CodeAddrSpace, DL),
1119-
getI32Imm(VecType, DL),
1094+
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
11201095
getI32Imm(FromType, DL),
11211096
getI32Imm(FromTypeWidth, DL),
11221097
Base,
@@ -1182,7 +1157,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11821157
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
11831158
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
11841159
? NVPTX::PTXLdStInstCode::Signed
1185-
: getLdStRegType(MemVT.getScalarType());
1160+
: NVPTX::PTXLdStInstCode::Untyped;
11861161

11871162
unsigned VecType;
11881163
unsigned FromTypeWidth;
@@ -1200,8 +1175,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12001175
}
12011176

12021177
if (isSubVectorPackedInI32(EltVT)) {
1178+
assert(ExtensionType == ISD::NON_EXTLOAD);
12031179
EltVT = MVT::i32;
1204-
FromType = NVPTX::PTXLdStInstCode::Untyped;
12051180
}
12061181

12071182
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
@@ -1405,21 +1380,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14051380
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
14061381

14071382
// Vector Setting
1408-
MVT SimpleVT = StoreVT.getSimpleVT();
1409-
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1410-
1411-
// Type Setting: toType + toTypeWidth
1412-
// - for integer type, always use 'u'
1413-
MVT ScalarVT = SimpleVT.getScalarType();
1414-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1415-
if (SimpleVT.isVector()) {
1416-
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1417-
"Unexpected vector type");
1418-
// v2x16 is stored using st.b32
1419-
ToTypeWidth = 32;
1420-
}
1421-
1422-
unsigned int ToType = getLdStRegType(ScalarVT);
1383+
const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();
14231384

14241385
// Create the machine instruction DAG
14251386
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
@@ -1434,8 +1395,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14341395
getI32Imm(Ordering, DL),
14351396
getI32Imm(Scope, DL),
14361397
getI32Imm(CodeAddrSpace, DL),
1437-
getI32Imm(VecType, DL),
1438-
getI32Imm(ToType, DL),
1398+
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
1399+
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
14391400
getI32Imm(ToTypeWidth, DL),
14401401
Base,
14411402
Offset,
@@ -1481,7 +1442,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14811442
// Type Setting: toType + toTypeWidth
14821443
// - for integer type, always use 'u'
14831444
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
1484-
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());
14851445

14861446
SmallVector<SDValue, 12> Ops;
14871447
SDValue N2;
@@ -1508,7 +1468,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15081468

15091469
if (isSubVectorPackedInI32(EltVT)) {
15101470
EltVT = MVT::i32;
1511-
ToType = NVPTX::PTXLdStInstCode::Untyped;
15121471
}
15131472

15141473
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
@@ -1519,8 +1478,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15191478

15201479
Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
15211480
getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
1522-
getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL), Base, Offset,
1523-
Chain});
1481+
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
1482+
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});
15241483

15251484
std::optional<unsigned> Opcode;
15261485
switch (N->getOpcode()) {

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,11 +2249,11 @@ def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">;
22492249
def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">;
22502250
def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">;
22512251
def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">;
2252-
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">;
2253-
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">;
2254-
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
2255-
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
2256-
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
2252+
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".b32">;
2253+
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".b64">;
2254+
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".b32">;
2255+
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".b64">;
2256+
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".b32">;
22572257

22582258
defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
22592259
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
@@ -2272,13 +2272,13 @@ defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
22722272
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
22732273
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;
22742274

2275-
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
2276-
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;
2275+
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".b32">;
2276+
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".b64">;
22772277

2278-
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
2279-
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
2278+
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".b32">;
2279+
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".b64">;
22802280

2281-
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
2281+
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".b32">;
22822282

22832283
def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
22842284
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;
@@ -2294,11 +2294,11 @@ def StoreRetvalV4I32 : StoreRetvalV4Inst<Int32Regs, ".b32">;
22942294
def StoreRetvalV4I16 : StoreRetvalV4Inst<Int16Regs, ".b16">;
22952295
def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">;
22962296

2297-
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">;
2298-
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">;
2299-
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">;
2300-
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">;
2301-
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">;
2297+
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".b64">;
2298+
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".b32">;
2299+
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".b64">;
2300+
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".b32">;
2301+
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".b32">;
23022302

23032303
def CallArgBeginInst : NVPTXInst<(outs), (ins), "(", [(CallArgBegin)]>;
23042304
def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;

0 commit comments

Comments
 (0)