Skip to content

Commit 504c119

Browse files
committed
Handle ptr addrspaces in kernel args
Change the asmprinter so that it will not force .const or .shared pointers to be changed to .global ones if they are explicitly annotated with those address-spaces (even though they are not expected to be present). Unify the code-path for printing addrspace annotations for both CUDA and CL, and only coerce generic pointers into .global pointers on CUDA. Emit alignment info for both CL and CUDA, but omit it on CUDA if it is not explicitly supplied. Update tests to have both aligned and unaligned pointers in all relevant addrspaces.
1 parent c8c51a0 commit 504c119

File tree

2 files changed

+47
-36
lines changed

2 files changed

+47
-36
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,35 +1603,34 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16031603
O << "\t.param .u" << PTySizeInBits << " ";
16041604

16051605
int addrSpace = PTy->getAddressSpace();
1606-
if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
1607-
NVPTX::CUDA) {
1606+
const bool IsCUDA =
1607+
static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
1608+
NVPTX::CUDA;
1609+
1610+
O << ".ptr ";
1611+
switch (addrSpace) {
1612+
default:
16081613
// Special handling for pointer arguments to kernel
16091614
// CUDA kernels assume that pointers are in global address space
16101615
// See:
16111616
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
1612-
O << ".ptr .global ";
1613-
if (I->getParamAlign().valueOrOne() != 1) {
1614-
Align ParamAlign = I->getParamAlign().value();
1615-
O << ".align " << ParamAlign.value() << " ";
1616-
}
1617-
} else {
1618-
switch (addrSpace) {
1619-
default:
1620-
O << ".ptr ";
1621-
break;
1622-
case ADDRESS_SPACE_CONST:
1623-
O << ".ptr .const ";
1624-
break;
1625-
case ADDRESS_SPACE_SHARED:
1626-
O << ".ptr .shared ";
1627-
break;
1628-
case ADDRESS_SPACE_GLOBAL:
1629-
O << ".ptr .global ";
1630-
break;
1631-
}
1632-
Align ParamAlign = I->getParamAlign().valueOrOne();
1633-
O << ".align " << ParamAlign.value() << " ";
1617+
if (IsCUDA)
1618+
O << " .global ";
1619+
break;
1620+
case ADDRESS_SPACE_CONST:
1621+
O << " .const ";
1622+
break;
1623+
case ADDRESS_SPACE_SHARED:
1624+
O << " .shared ";
1625+
break;
1626+
case ADDRESS_SPACE_GLOBAL:
1627+
O << " .global ";
1628+
break;
16341629
}
1630+
1631+
Align ParamAlign = I->getParamAlign().valueOrOne();
1632+
if (ParamAlign != 1 || !IsCUDA)
1633+
O << ".align " << ParamAlign.value() << " ";
16351634
O << TLI->getParamName(F, paramIndex);
16361635
continue;
16371636
}
Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
; RUN: llc < %s -march=nvptx64 -mcpu=sm_72 2>&1 | FileCheck %s
2-
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_72 | %ptxas-verify %}
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_60 | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_60 | %ptxas-verify %}
33

44
%struct.Large = type { [16 x double] }
55

66
; CHECK-LABEL: .entry func_align(
7-
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
8-
; CHECK: .param .u64 .ptr .global func_align_param_1,
9-
; CHECK: .param .u64 .ptr .global func_align_param_2
10-
define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, ptr addrspace(3) %n) {
7+
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0
8+
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_1
9+
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_2
10+
; CHECK: .param .u64 .ptr .shared .align 16 func_align_param_3
11+
; CHECK: .param .u64 .ptr .const .align 16 func_align_param_4
12+
define void @func_align(ptr nocapture readonly align 16 %input,
13+
ptr nocapture align 16 %out,
14+
ptr addrspace(1) align 16 %global,
15+
ptr addrspace(3) align 16 %shared,
16+
ptr addrspace(4) align 16 %const) {
1117
entry:
1218
%0 = addrspacecast ptr %out to ptr addrspace(1)
1319
%1 = addrspacecast ptr %input to ptr addrspace(1)
@@ -17,11 +23,17 @@ entry:
1723
ret void
1824
}
1925

20-
; CHECK-LABEL: .entry func(
21-
; CHECK: .param .u64 .ptr .global func_param_0,
22-
; CHECK: .param .u64 .ptr .global func_param_1,
23-
; CHECK: .param .u32 func_param_2
24-
define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) {
26+
; CHECK-LABEL: .entry func_noalign(
27+
; CHECK: .param .u64 .ptr .global func_noalign_param_0
28+
; CHECK: .param .u64 .ptr .global func_noalign_param_1
29+
; CHECK: .param .u64 .ptr .global func_noalign_param_2
30+
; CHECK: .param .u64 .ptr .shared func_noalign_param_3
31+
; CHECK: .param .u64 .ptr .const func_noalign_param_4
32+
define void @func_noalign(ptr nocapture readonly %input,
33+
ptr nocapture %out,
34+
ptr addrspace(1) %global,
35+
ptr addrspace(3) %shared,
36+
ptr addrspace(4) %const) {
2537
entry:
2638
%0 = addrspacecast ptr %out to ptr addrspace(1)
2739
%1 = addrspacecast ptr %input to ptr addrspace(1)
@@ -33,4 +45,4 @@ entry:
3345

3446
!nvvm.annotations = !{!0, !1}
3547
!0 = !{ptr @func_align, !"kernel", i32 1}
36-
!1 = !{ptr @func, !"kernel", i32 1}
48+
!1 = !{ptr @func_noalign, !"kernel", i32 1}

0 commit comments

Comments
 (0)