Skip to content

Commit 6c2f5d6

Browse files
author
Hugh Delaney
authored
[NVPTX] Don't use underlying alignment to align param (#96793)
Previously, if a ptr had align N, then the NVPTX lowering was taking this align N to refer to the alignment of the pointer type itself, as opposed to the alignment of the memory that it points to. As such, if a kernel with signature ``` define void @foo(ptr align 4 %_arg_ptr) ``` takes align 4 to be the alignment of the parameter, this would result in breaking the ld.param into two separate loads like so: ``` ld.param.u32 %rd1, [foo_param_0+4]; shl.b64 %rd2, %rd1, 32; ld.param.u32 %rd3, [foo_param_0]; or.b64 %rd4, %rd2, %rd3; ``` It isn't necessary as far as I can tell from the PTX ISA documents to specify the alignment of params, nor to break the loading of params into smaller loads based on their alignment. So this patch changes the codegen to the better: ``` ld.param.u64 %rd1, [foo_param_0]; ```
1 parent ba60d8a commit 6c2f5d6

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3232,9 +3232,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
32323232
if (NumElts != 1)
32333233
return std::nullopt;
32343234
Align PartAlign =
3235-
(Offsets[parti] == 0 && PAL.getParamAlignment(i))
3236-
? PAL.getParamAlignment(i).value()
3237-
: DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3235+
DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
32383236
return commonAlignment(PartAlign, Offsets[parti]);
32393237
}();
32403238
SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,

llvm/test/CodeGen/NVPTX/param-align.ll

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,56 @@ define ptx_device void @t6() {
6969
call void %fp(ptr byval(i8) null);
7070
ret void
7171
}
72+
73+
; CHECK-LABEL: .func check_ptr_align1(
74+
; CHECK: ld.param.u64 %rd1, [check_ptr_align1_param_0];
75+
; CHECK-NOT: ld.param.u8
76+
; CHECK: mov.b32 %r1, 0;
77+
; CHECK: st.u8 [%rd1+3], %r1;
78+
; CHECK: st.u8 [%rd1+2], %r1;
79+
; CHECK: st.u8 [%rd1+1], %r1;
80+
; CHECK: mov.b32 %r2, 1;
81+
; CHECK: st.u8 [%rd1], %r2;
82+
; CHECK: ret;
83+
define void @check_ptr_align1(ptr align 1 %_arg_ptr) {
84+
entry:
85+
store i32 1, ptr %_arg_ptr, align 1
86+
ret void
87+
}
88+
89+
; CHECK-LABEL: .func check_ptr_align2(
90+
; CHECK: ld.param.u64 %rd1, [check_ptr_align2_param_0];
91+
; CHECK-NOT: ld.param.u16
92+
; CHECK: mov.b32 %r1, 0;
93+
; CHECK: st.u16 [%rd1+2], %r1;
94+
; CHECK: mov.b32 %r2, 2;
95+
; CHECK: st.u16 [%rd1], %r2;
96+
; CHECK: ret;
97+
define void @check_ptr_align2(ptr align 2 %_arg_ptr) {
98+
entry:
99+
store i32 2, ptr %_arg_ptr, align 2
100+
ret void
101+
}
102+
103+
; CHECK-LABEL: .func check_ptr_align4(
104+
; CHECK: ld.param.u64 %rd1, [check_ptr_align4_param_0];
105+
; CHECK-NOT: ld.param.u32
106+
; CHECK: mov.b32 %r1, 4;
107+
; CHECK: st.u32 [%rd1], %r1;
108+
; CHECK: ret;
109+
define void @check_ptr_align4(ptr align 4 %_arg_ptr) {
110+
entry:
111+
store i32 4, ptr %_arg_ptr, align 4
112+
ret void
113+
}
114+
115+
; CHECK-LABEL: .func check_ptr_align8(
116+
; CHECK: ld.param.u64 %rd1, [check_ptr_align8_param_0];
117+
; CHECK: mov.b32 %r1, 8;
118+
; CHECK: st.u32 [%rd1], %r1;
119+
; CHECK: ret;
120+
define void @check_ptr_align8(ptr align 8 %_arg_ptr) {
121+
entry:
122+
store i32 8, ptr %_arg_ptr, align 8
123+
ret void
124+
}

0 commit comments

Comments
 (0)