Skip to content

Commit 8da3a8f

Browse files
authored
[NVPTX] fixup support for over-aligned parameters (#92457)
This extends the NVPTX support for over-aligned parameters and return values in a few related ways:
- Support the `alignstack` attribute as an alternative to the legacy nvvm `!"align"` metadata entries. The legacy support is still maintained, but in the long term it might be nice to auto-upgrade to `alignstack`.
- Check the alignment info when emitting the parameter list to prevent a mismatch between the alignment used by the caller and by the callee, which would previously cause a fatal error in `ptxas`.
- Check the alignment info when emitting loads for parameters, potentially enabling better vectorization.
1 parent e2db08f commit 8da3a8f

File tree

6 files changed

+168
-43
lines changed

6 files changed

+168
-43
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
#include "llvm/MC/MCStreamer.h"
7373
#include "llvm/MC/MCSymbol.h"
7474
#include "llvm/MC/TargetRegistry.h"
75+
#include "llvm/Support/Alignment.h"
7576
#include "llvm/Support/Casting.h"
7677
#include "llvm/Support/CommandLine.h"
7778
#include "llvm/Support/Endian.h"
@@ -370,11 +371,10 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
370371
<< " func_retval0";
371372
} else if (ShouldPassAsArray(Ty)) {
372373
unsigned totalsz = DL.getTypeAllocSize(Ty);
373-
unsigned retAlignment = 0;
374-
if (!getAlign(*F, 0, retAlignment))
375-
retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
376-
O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
377-
<< "]";
374+
Align RetAlignment = TLI->getFunctionArgumentAlignment(
375+
F, Ty, AttributeList::ReturnIndex, DL);
376+
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
377+
<< totalsz << "]";
378378
} else
379379
llvm_unreachable("Unknown return type");
380380
} else {
@@ -1558,6 +1558,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
15581558

15591559
auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
15601560
paramIndex](Type *Ty) -> Align {
1561+
if (MaybeAlign StackAlign =
1562+
getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
1563+
return StackAlign.value();
1564+
15611565
Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
15621566
MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
15631567
return std::max(TypeAlign, ParamAlign.valueOrOne());

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,12 +1434,11 @@ std::string NVPTXTargetLowering::getPrototype(
14341434

14351435
if (!Outs[OIdx].Flags.isByVal()) {
14361436
if (IsTypePassedAsArray(Ty)) {
1437-
unsigned ParamAlign = 0;
14381437
const CallInst *CallI = cast<CallInst>(&CB);
1439-
// +1 because index 0 is reserved for return type alignment
1440-
if (!getAlign(*CallI, i + 1, ParamAlign))
1441-
ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value();
1442-
O << ".param .align " << ParamAlign << " .b8 ";
1438+
Align ParamAlign =
1439+
getAlign(*CallI, i + AttributeList::FirstArgIndex)
1440+
.value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1441+
O << ".param .align " << ParamAlign.value() << " .b8 ";
14431442
O << "_";
14441443
O << "[" << DL.getTypeAllocSize(Ty) << "]";
14451444
// update the index for Outs
@@ -1489,6 +1488,11 @@ std::string NVPTXTargetLowering::getPrototype(
14891488
return Prototype;
14901489
}
14911490

1491+
Align NVPTXTargetLowering::getFunctionArgumentAlignment(
1492+
const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1493+
return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1494+
}
1495+
14921496
Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
14931497
unsigned Idx,
14941498
const DataLayout &DL) const {
@@ -1497,7 +1501,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
14971501
return DL.getABITypeAlign(Ty);
14981502
}
14991503

1500-
unsigned Alignment = 0;
15011504
const Function *DirectCallee = CB->getCalledFunction();
15021505

15031506
if (!DirectCallee) {
@@ -1507,21 +1510,16 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
15071510
// With bitcast'd call targets, the instruction will be the call
15081511
if (const auto *CI = dyn_cast<CallInst>(CB)) {
15091512
// Check if we have call alignment metadata
1510-
if (getAlign(*CI, Idx, Alignment))
1511-
return Align(Alignment);
1513+
if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1514+
return StackAlign.value();
15121515
}
15131516
DirectCallee = getMaybeBitcastedCallee(CB);
15141517
}
15151518

15161519
// Check for function alignment information if we found that the
15171520
// ultimate target is a Function
1518-
if (DirectCallee) {
1519-
if (getAlign(*DirectCallee, Idx, Alignment))
1520-
return Align(Alignment);
1521-
// If alignment information is not available, fall back to the
1522-
// default function param optimized type alignment
1523-
return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL);
1524-
}
1521+
if (DirectCallee)
1522+
return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
15251523

15261524
// Call is indirect, fall back to the ABI type alignment
15271525
return DL.getABITypeAlign(Ty);
@@ -3195,8 +3193,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
31953193
if (VTs.empty())
31963194
report_fatal_error("Empty parameter types are not supported");
31973195

3198-
auto VectorInfo =
3199-
VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlign(Ty));
3196+
Align ArgAlign = getFunctionArgumentAlignment(
3197+
F, Ty, i + AttributeList::FirstArgIndex, DL);
3198+
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
32003199

32013200
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
32023201
int VecIdx = -1; // Index of the first element of the current vector.

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ class NVPTXTargetLowering : public TargetLowering {
462462
MachineFunction &MF,
463463
unsigned Intrinsic) const override;
464464

465+
Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
466+
const DataLayout &DL) const;
467+
465468
/// getFunctionParamOptimizedAlign - since function arguments are passed via
466469
/// .param space, we may want to increase their alignment in a way that
467470
/// ensures that we can effectively vectorize their loads & stores. We can

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
#include "llvm/IR/InstIterator.h"
2020
#include "llvm/IR/Module.h"
2121
#include "llvm/IR/Operator.h"
22+
#include "llvm/Support/Alignment.h"
2223
#include "llvm/Support/Mutex.h"
2324
#include <algorithm>
2425
#include <cstring>
2526
#include <map>
2627
#include <mutex>
28+
#include <optional>
2729
#include <string>
2830
#include <vector>
2931

@@ -296,37 +298,44 @@ bool isKernelFunction(const Function &F) {
296298
return (x == 1);
297299
}
298300

299-
bool getAlign(const Function &F, unsigned index, unsigned &align) {
301+
MaybeAlign getAlign(const Function &F, unsigned Index) {
302+
// First check the alignstack metadata
303+
if (MaybeAlign StackAlign =
304+
F.getAttributes().getAttributes(Index).getStackAlignment())
305+
return StackAlign;
306+
307+
// If that is missing, check the legacy nvvm metadata
300308
std::vector<unsigned> Vs;
301309
bool retval = findAllNVVMAnnotation(&F, "align", Vs);
302310
if (!retval)
303-
return false;
304-
for (unsigned v : Vs) {
305-
if ((v >> 16) == index) {
306-
align = v & 0xFFFF;
307-
return true;
308-
}
309-
}
310-
return false;
311+
return std::nullopt;
312+
for (unsigned V : Vs)
313+
if ((V >> 16) == Index)
314+
return Align(V & 0xFFFF);
315+
316+
return std::nullopt;
311317
}
312318

313-
bool getAlign(const CallInst &I, unsigned index, unsigned &align) {
319+
MaybeAlign getAlign(const CallInst &I, unsigned Index) {
320+
// First check the alignstack metadata
321+
if (MaybeAlign StackAlign =
322+
I.getAttributes().getAttributes(Index).getStackAlignment())
323+
return StackAlign;
324+
325+
// If that is missing, check the legacy nvvm metadata
314326
if (MDNode *alignNode = I.getMetadata("callalign")) {
315327
for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
316328
if (const ConstantInt *CI =
317329
mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
318-
unsigned v = CI->getZExtValue();
319-
if ((v >> 16) == index) {
320-
align = v & 0xFFFF;
321-
return true;
322-
}
323-
if ((v >> 16) > index) {
324-
return false;
325-
}
330+
unsigned V = CI->getZExtValue();
331+
if ((V >> 16) == Index)
332+
return Align(V & 0xFFFF);
333+
if ((V >> 16) > Index)
334+
return std::nullopt;
326335
}
327336
}
328337
}
329-
return false;
338+
return std::nullopt;
330339
}
331340

332341
Function *getMaybeBitcastedCallee(const CallBase *CB) {

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/IR/GlobalVariable.h"
1919
#include "llvm/IR/IntrinsicInst.h"
2020
#include "llvm/IR/Value.h"
21+
#include "llvm/Support/Alignment.h"
2122
#include <cstdarg>
2223
#include <set>
2324
#include <string>
@@ -60,8 +61,8 @@ bool getMinCTASm(const Function &, unsigned &);
6061
bool getMaxNReg(const Function &, unsigned &);
6162
bool isKernelFunction(const Function &);
6263

63-
bool getAlign(const Function &, unsigned index, unsigned &);
64-
bool getAlign(const CallInst &, unsigned index, unsigned &);
64+
MaybeAlign getAlign(const Function &, unsigned);
65+
MaybeAlign getAlign(const CallInst &, unsigned);
6566
Function *getMaybeBitcastedCallee(const CallBase *CB);
6667

6768
// PTX ABI requires all scalar argument/return values to have
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
; RUN: llc < %s -march=nvptx | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx -verify-machineinstrs | %ptxas-verify %}
3+
4+
target triple = "nvptx64-nvidia-cuda"
5+
6+
%struct.float2 = type { float, float }
7+
8+
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md
9+
; CHECK-NEXT: (
10+
; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8]
11+
; CHECK-NEXT: )
12+
; CHECK-NEXT: ;
13+
14+
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee
15+
; CHECK-NEXT: (
16+
; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8]
17+
; CHECK-NEXT: )
18+
; CHECK-NEXT: ;
19+
20+
define float @caller_md(float %a, float %b) {
21+
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller_md(
22+
; CHECK-NEXT: .param .b32 caller_md_param_0,
23+
; CHECK-NEXT: .param .b32 caller_md_param_1
24+
; CHECK-NEXT: )
25+
; CHECK-NEXT: {
26+
27+
; CHECK: ld.param.f32 %f1, [caller_md_param_0];
28+
; CHECK-NEXT: ld.param.f32 %f2, [caller_md_param_1];
29+
; CHECK-NEXT: {
30+
; CHECK-NEXT: .param .align 8 .b8 param0[8];
31+
; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2};
32+
; CHECK-NEXT: .param .b32 retval0;
33+
; CHECK-NEXT: call.uni (retval0),
34+
; CHECK-NEXT: callee_md,
35+
; CHECK-NEXT: (
36+
; CHECK-NEXT: param0
37+
; CHECK-NEXT: );
38+
; CHECK-NEXT: ld.param.f32 %f3, [retval0+0];
39+
; CHECK-NEXT: }
40+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
41+
; CHECK-NEXT: ret;
42+
%s1 = insertvalue %struct.float2 poison, float %a, 0
43+
%s2 = insertvalue %struct.float2 %s1, float %b, 1
44+
%r = call float @callee_md(%struct.float2 %s2)
45+
ret float %r
46+
}
47+
48+
define float @callee_md(%struct.float2 %a) {
49+
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md(
50+
; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8]
51+
; CHECK-NEXT: )
52+
; CHECK-NEXT: {
53+
54+
; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_md_param_0];
55+
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
56+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
57+
; CHECK-NEXT: ret;
58+
%v0 = extractvalue %struct.float2 %a, 0
59+
%v1 = extractvalue %struct.float2 %a, 1
60+
%2 = fadd float %v0, %v1
61+
ret float %2
62+
}
63+
64+
define float @caller(float %a, float %b) {
65+
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller(
66+
; CHECK-NEXT: .param .b32 caller_param_0,
67+
; CHECK-NEXT: .param .b32 caller_param_1
68+
; CHECK-NEXT: )
69+
; CHECK-NEXT: {
70+
71+
; CHECK: ld.param.f32 %f1, [caller_param_0];
72+
; CHECK-NEXT: ld.param.f32 %f2, [caller_param_1];
73+
; CHECK-NEXT: {
74+
; CHECK-NEXT: .param .align 8 .b8 param0[8];
75+
; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2};
76+
; CHECK-NEXT: .param .b32 retval0;
77+
; CHECK-NEXT: call.uni (retval0),
78+
; CHECK-NEXT: callee,
79+
; CHECK-NEXT: (
80+
; CHECK-NEXT: param0
81+
; CHECK-NEXT: );
82+
; CHECK-NEXT: ld.param.f32 %f3, [retval0+0];
83+
; CHECK-NEXT: }
84+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
85+
; CHECK-NEXT: ret;
86+
%s1 = insertvalue %struct.float2 poison, float %a, 0
87+
%s2 = insertvalue %struct.float2 %s1, float %b, 1
88+
%r = call float @callee(%struct.float2 %s2)
89+
ret float %r
90+
}
91+
92+
define float @callee(%struct.float2 alignstack(8) %a ) {
93+
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee(
94+
; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8]
95+
; CHECK-NEXT: )
96+
; CHECK-NEXT: {
97+
98+
; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_param_0];
99+
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
100+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
101+
; CHECK-NEXT: ret;
102+
%v0 = extractvalue %struct.float2 %a, 0
103+
%v1 = extractvalue %struct.float2 %a, 1
104+
%2 = fadd float %v0, %v1
105+
ret float %2
106+
}
107+
108+
!nvvm.annotations = !{!0}
109+
!0 = !{ptr @callee_md, !"align", i32 u0x00010008}

0 commit comments

Comments
 (0)