Skip to content

[NVPTX] use untyped loads and stores wherever possible #137698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/test/CodeGenCUDA/bf16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
//
__device__ void test_arg(__bf16 *out, __bf16 in) {
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
// CHECK-DAG: ld.param.b64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
// CHECK-DAG: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_argPDF16bDF16b_param_1];
__bf16 bf16 = in;
*out = bf16;
Expand Down
28 changes: 14 additions & 14 deletions clang/test/CodeGenCUDA/fp-contract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -179,26 +179,26 @@
__host__ __device__ float func(float a, float b, float c) { return a + b * c; }
// COMMON-LABEL: _Z4funcfff
// NV-ON: fma.rn.f32
// NV-ON-NEXT: st.param.f32
// NV-ON-NEXT: st.param.b32
// AMD-ON: v_fmac_f32_e64
// AMD-ON-NEXT: s_setpc_b64

// NV-OFF: mul.rn.f32
// NV-OFF-NEXT: add.rn.f32
// NV-OFF-NEXT: st.param.f32
// NV-OFF-NEXT: st.param.b32
// AMD-OFF: v_mul_f32_e64
// AMD-OFF-NEXT: v_add_f32_e64
// AMD-OFF-NEXT: s_setpc_b64

// NV-OPT-FAST: fma.rn.f32
// NV-OPT-FAST-NEXT: st.param.f32
// NV-OPT-FAST-NEXT: st.param.b32
// NV-OPT-FASTSTD: fma.rn.f32
// NV-OPT-FASTSTD-NEXT: st.param.f32
// NV-OPT-FASTSTD-NEXT: st.param.b32
// NV-OPT-ON: fma.rn.f32
// NV-OPT-ON-NEXT: st.param.f32
// NV-OPT-ON-NEXT: st.param.b32
// NV-OPT-OFF: mul.rn.f32
// NV-OPT-OFF-NEXT: add.rn.f32
// NV-OPT-OFF-NEXT: st.param.f32
// NV-OPT-OFF-NEXT: st.param.b32

// AMD-OPT-FAST-IR: fmul contract float
// AMD-OPT-FAST-IR: fadd contract float
Expand All @@ -224,15 +224,15 @@ __host__ __device__ float func2(float a, float b, float c) {
}
// COMMON-LABEL: _Z5func2fff
// NV-OPT-FAST: fma.rn.f32
// NV-OPT-FAST-NEXT: st.param.f32
// NV-OPT-FAST-NEXT: st.param.b32
// NV-OPT-FASTSTD: fma.rn.f32
// NV-OPT-FASTSTD-NEXT: st.param.f32
// NV-OPT-FASTSTD-NEXT: st.param.b32
// NV-OPT-ON: mul.rn.f32
// NV-OPT-ON: add.rn.f32
// NV-OPT-ON-NEXT: st.param.f32
// NV-OPT-ON-NEXT: st.param.b32
// NV-OPT-OFF: mul.rn.f32
// NV-OPT-OFF: add.rn.f32
// NV-OPT-OFF-NEXT: st.param.f32
// NV-OPT-OFF-NEXT: st.param.b32

// AMD-OPT-FAST-IR: fmul contract float
// AMD-OPT-FAST-IR: fadd contract float
Expand Down Expand Up @@ -267,16 +267,16 @@ __host__ __device__ float func2(float a, float b, float c) {
}
// COMMON-LABEL: _Z5func3fff
// NV-OPT-FAST: fma.rn.f32
// NV-OPT-FAST-NEXT: st.param.f32
// NV-OPT-FAST-NEXT: st.param.b32
// NV-OPT-FASTSTD: mul.rn.f32
// NV-OPT-FASTSTD: add.rn.f32
// NV-OPT-FASTSTD-NEXT: st.param.f32
// NV-OPT-FASTSTD-NEXT: st.param.b32
// NV-OPT-ON: mul.rn.f32
// NV-OPT-ON: add.rn.f32
// NV-OPT-ON-NEXT: st.param.f32
// NV-OPT-ON-NEXT: st.param.b32
// NV-OPT-OFF: mul.rn.f32
// NV-OPT-OFF: add.rn.f32
// NV-OPT-OFF-NEXT: st.param.f32
// NV-OPT-OFF-NEXT: st.param.b32

// AMD-OPT-FAST-IR: fmul float
// AMD-OPT-FAST-IR: fadd float
Expand Down
28 changes: 14 additions & 14 deletions clang/test/CodeGenCUDA/memcpy-libcall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
// PTX-LABEL: .func _Z12copy_genericPvPKv(
void __device__ copy_generic(void *dest, const void *src) {
__builtin_memcpy(dest, src, 32);
// PTX: ld.u8
// PTX: st.u8
// PTX: ld.b8
// PTX: st.b8
}

// PTX-LABEL: .entry _Z11copy_globalPvS_(
void __global__ copy_global(void *dest, void * src) {
__builtin_memcpy(dest, src, 32);
// PTX: ld.global.u8
// PTX: st.global.u8
// PTX: ld.global.b8
// PTX: st.global.b8
}

struct S {
Expand All @@ -28,37 +28,37 @@ struct S {
// PTX-LABEL: .entry _Z20copy_param_to_globalP1SS_(
void __global__ copy_param_to_global(S *global, S param) {
__builtin_memcpy(global, &param, sizeof(S));
// PTX: ld.param.u32
// PTX: st.global.u32
// PTX: ld.param.b32
// PTX: st.global.b32
}

// PTX-LABEL: .entry _Z19copy_param_to_localPU3AS51SS_(
void __global__ copy_param_to_local(__attribute__((address_space(5))) S *local,
S param) {
__builtin_memcpy(local, &param, sizeof(S));
// PTX: ld.param.u32
// PTX: st.local.u32
// PTX: ld.param.b32
// PTX: st.local.b32
}

// PTX-LABEL: .func _Z21copy_local_to_genericP1SPU3AS5S_(
void __device__ copy_local_to_generic(S *generic,
__attribute__((address_space(5))) S *src) {
__builtin_memcpy(generic, src, sizeof(S));
// PTX: ld.local.u32
// PTX: st.u32
// PTX: ld.local.b32
// PTX: st.b32
}

__shared__ S shared;

// PTX-LABEL: .entry _Z20copy_param_to_shared1S(
void __global__ copy_param_to_shared( S param) {
__builtin_memcpy(&shared, &param, sizeof(S));
// PTX: ld.param.u32
// PTX: st.shared.u32
// PTX: ld.param.b32
// PTX: st.shared.b32
}

void __device__ copy_shared_to_generic(S *generic) {
__builtin_memcpy(generic, &shared, sizeof(S));
// PTX: ld.shared.u32
// PTX: st.u32
// PTX: ld.shared.b32
// PTX: st.b32
}
67 changes: 13 additions & 54 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1044,21 +1044,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
}
}

static int getLdStRegType(EVT VT) {
if (VT.isFloatingPoint())
switch (VT.getSimpleVT().SimpleTy) {
case MVT::f16:
case MVT::bf16:
case MVT::v2f16:
case MVT::v2bf16:
return NVPTX::PTXLdStInstCode::Untyped;
default:
return NVPTX::PTXLdStInstCode::Float;
}
else
return NVPTX::PTXLdStInstCode::Unsigned;
}

bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
MemSDNode *LD = cast<MemSDNode>(N);
assert(LD->readMem() && "Expected load");
Expand Down Expand Up @@ -1088,24 +1073,14 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// type is integer
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
MVT SimpleVT = LoadedVT.getSimpleVT();
MVT ScalarVT = SimpleVT.getScalarType();
// Read at least 8 bits (predicates are stored as 8-bit values)
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
unsigned int FromType;
unsigned FromTypeWidth = std::max(8U, (unsigned)SimpleVT.getSizeInBits());

// Vector Setting
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
"Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
FromTypeWidth = 32;
}

if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
FromType = NVPTX::PTXLdStInstCode::Signed;
else
FromType = getLdStRegType(ScalarVT);
unsigned int FromType =
(PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;

assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && "Invalid width for load");
Expand All @@ -1116,7 +1091,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
SDValue Ops[] = {getI32Imm(Ordering, DL),
getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL),
getI32Imm(VecType, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
Base,
Expand Down Expand Up @@ -1182,7 +1157,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: getLdStRegType(MemVT.getScalarType());
: NVPTX::PTXLdStInstCode::Untyped;

unsigned VecType;
unsigned FromTypeWidth;
Expand All @@ -1200,8 +1175,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
}

if (isSubVectorPackedInI32(EltVT)) {
assert(ExtensionType == ISD::NON_EXTLOAD);
EltVT = MVT::i32;
FromType = NVPTX::PTXLdStInstCode::Untyped;
}

assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
Expand Down Expand Up @@ -1405,21 +1380,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);

// Vector Setting
MVT SimpleVT = StoreVT.getSimpleVT();
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;

// Type Setting: toType + toTypeWidth
// - for integer type, always use 'u'
MVT ScalarVT = SimpleVT.getScalarType();
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
"Unexpected vector type");
// v2x16 is stored using st.b32
ToTypeWidth = 32;
}

unsigned int ToType = getLdStRegType(ScalarVT);
const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();

// Create the machine instruction DAG
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
Expand All @@ -1434,8 +1395,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
getI32Imm(Ordering, DL),
getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL),
getI32Imm(VecType, DL),
getI32Imm(ToType, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
getI32Imm(ToTypeWidth, DL),
Base,
Offset,
Expand Down Expand Up @@ -1481,7 +1442,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
// Type Setting: toTypeWidth only
// - toType is always NVPTX::PTXLdStInstCode::Untyped, regardless of element type
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());

SmallVector<SDValue, 12> Ops;
SDValue N2;
Expand All @@ -1508,7 +1468,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {

if (isSubVectorPackedInI32(EltVT)) {
EltVT = MVT::i32;
ToType = NVPTX::PTXLdStInstCode::Untyped;
}

assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
Expand All @@ -1519,8 +1478,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {

Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL), Base, Offset,
Chain});
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});

std::optional<unsigned> Opcode;
switch (N->getOpcode()) {
Expand Down
30 changes: 15 additions & 15 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2249,11 +2249,11 @@ def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">;
def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">;
def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">;
def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">;
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">;
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">;
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".b32">;
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".b64">;
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".b32">;
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".b64">;
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".b32">;

defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
Expand All @@ -2272,13 +2272,13 @@ defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;

defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".b32">;
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".b64">;

defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".b32">;
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".b64">;

defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".b32">;

def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;
Expand All @@ -2294,11 +2294,11 @@ def StoreRetvalV4I32 : StoreRetvalV4Inst<Int32Regs, ".b32">;
def StoreRetvalV4I16 : StoreRetvalV4Inst<Int16Regs, ".b16">;
def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">;

def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">;
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">;
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">;
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">;
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">;
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".b64">;
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".b32">;
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".b64">;
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".b32">;
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".b32">;

def CallArgBeginInst : NVPTXInst<(outs), (ins), "(", [(CallArgBegin)]>;
def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
Expand Down
Loading