Skip to content

[NVPTX] fixup support for over-aligned parameters #92457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
Expand Down Expand Up @@ -370,11 +371,10 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
<< " func_retval0";
} else if (ShouldPassAsArray(Ty)) {
unsigned totalsz = DL.getTypeAllocSize(Ty);
unsigned retAlignment = 0;
if (!getAlign(*F, 0, retAlignment))
retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
<< "]";
Align RetAlignment = TLI->getFunctionArgumentAlignment(
F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
<< totalsz << "]";
} else
llvm_unreachable("Unknown return type");
} else {
Expand Down Expand Up @@ -1558,6 +1558,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {

auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
paramIndex](Type *Ty) -> Align {
if (MaybeAlign StackAlign =
getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
return StackAlign.value();

Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
return std::max(TypeAlign, ParamAlign.valueOrOne());
Expand Down
33 changes: 16 additions & 17 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1434,12 +1434,11 @@ std::string NVPTXTargetLowering::getPrototype(

if (!Outs[OIdx].Flags.isByVal()) {
if (IsTypePassedAsArray(Ty)) {
unsigned ParamAlign = 0;
const CallInst *CallI = cast<CallInst>(&CB);
// +1 because index 0 is reserved for return type alignment
if (!getAlign(*CallI, i + 1, ParamAlign))
ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value();
O << ".param .align " << ParamAlign << " .b8 ";
Align ParamAlign =
getAlign(*CallI, i + AttributeList::FirstArgIndex)
.value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
O << ".param .align " << ParamAlign.value() << " .b8 ";
O << "_";
O << "[" << DL.getTypeAllocSize(Ty) << "]";
// update the index for Outs
Expand Down Expand Up @@ -1489,6 +1488,11 @@ std::string NVPTXTargetLowering::getPrototype(
return Prototype;
}

/// Return the alignment to use for the return value or an argument of \p F.
///
/// \p Idx is forwarded unchanged to getAlign(); the visible callers pass
/// either AttributeList::ReturnIndex or `ArgNo + AttributeList::FirstArgIndex`.
/// An alignment reported by getAlign() takes precedence; when there is none,
/// the result of getFunctionParamOptimizedAlign() for \p Ty is returned.
Align NVPTXTargetLowering::getFunctionArgumentAlignment(
    const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
  const MaybeAlign ExplicitAlign = getAlign(*F, Idx);
  // The fallback is computed unconditionally (and after the lookup above),
  // which is the same evaluation the single-expression value_or() form has.
  const Align FallbackAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
  return ExplicitAlign ? *ExplicitAlign : FallbackAlign;
}

Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
unsigned Idx,
const DataLayout &DL) const {
Expand All @@ -1497,7 +1501,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
return DL.getABITypeAlign(Ty);
}

unsigned Alignment = 0;
const Function *DirectCallee = CB->getCalledFunction();

if (!DirectCallee) {
Expand All @@ -1507,21 +1510,16 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
// With bitcast'd call targets, the instruction will be the call
if (const auto *CI = dyn_cast<CallInst>(CB)) {
// Check if we have call alignment metadata
if (getAlign(*CI, Idx, Alignment))
return Align(Alignment);
if (MaybeAlign StackAlign = getAlign(*CI, Idx))
return StackAlign.value();
}
DirectCallee = getMaybeBitcastedCallee(CB);
}

// Check for function alignment information if we found that the
// ultimate target is a Function
if (DirectCallee) {
if (getAlign(*DirectCallee, Idx, Alignment))
return Align(Alignment);
// If alignment information is not available, fall back to the
// default function param optimized type alignment
return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL);
}
if (DirectCallee)
return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);

// Call is indirect, fall back to the ABI type alignment
return DL.getABITypeAlign(Ty);
Expand Down Expand Up @@ -3195,8 +3193,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (VTs.empty())
report_fatal_error("Empty parameter types are not supported");

auto VectorInfo =
VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlign(Ty));
Align ArgAlign = getFunctionArgumentAlignment(
F, Ty, i + AttributeList::FirstArgIndex, DL);
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);

SDValue Arg = getParamSymbol(DAG, i, PtrVT);
int VecIdx = -1; // Index of the first element of the current vector.
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ class NVPTXTargetLowering : public TargetLowering {
MachineFunction &MF,
unsigned Intrinsic) const override;

/// getFunctionArgumentAlignment - return the alignment to use for the return
/// value or argument of \p F selected by \p Idx: the alignment reported by
/// getAlign(*F, Idx) when there is one, otherwise the value of
/// getFunctionParamOptimizedAlign(F, Ty, DL).
Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
const DataLayout &DL) const;

/// getFunctionParamOptimizedAlign - since function arguments are passed via
/// .param space, we may want to increase their alignment in a way that
/// ensures that we can effectively vectorize their loads & stores. We can
Expand Down
47 changes: 28 additions & 19 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Mutex.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

Expand Down Expand Up @@ -296,37 +298,44 @@ bool isKernelFunction(const Function &F) {
return (x == 1);
}

bool getAlign(const Function &F, unsigned index, unsigned &align) {
MaybeAlign getAlign(const Function &F, unsigned Index) {
// First check the alignstack metadata
if (MaybeAlign StackAlign =
F.getAttributes().getAttributes(Index).getStackAlignment())
return StackAlign;

// If that is missing, check the legacy nvvm metadata
std::vector<unsigned> Vs;
bool retval = findAllNVVMAnnotation(&F, "align", Vs);
if (!retval)
return false;
for (unsigned v : Vs) {
if ((v >> 16) == index) {
align = v & 0xFFFF;
return true;
}
}
return false;
return std::nullopt;
for (unsigned V : Vs)
if ((V >> 16) == Index)
return Align(V & 0xFFFF);

return std::nullopt;
}

bool getAlign(const CallInst &I, unsigned index, unsigned &align) {
MaybeAlign getAlign(const CallInst &I, unsigned Index) {
// First check the alignstack metadata
if (MaybeAlign StackAlign =
I.getAttributes().getAttributes(Index).getStackAlignment())
return StackAlign;

// If that is missing, check the legacy nvvm metadata
if (MDNode *alignNode = I.getMetadata("callalign")) {
for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
if (const ConstantInt *CI =
mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
unsigned v = CI->getZExtValue();
if ((v >> 16) == index) {
align = v & 0xFFFF;
return true;
}
if ((v >> 16) > index) {
return false;
}
unsigned V = CI->getZExtValue();
if ((V >> 16) == Index)
return Align(V & 0xFFFF);
if ((V >> 16) > Index)
return std::nullopt;
}
}
}
return false;
return std::nullopt;
}

Function *getMaybeBitcastedCallee(const CallBase *CB) {
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"
#include <cstdarg>
#include <set>
#include <string>
Expand Down Expand Up @@ -60,8 +61,8 @@ bool getMinCTASm(const Function &, unsigned &);
bool getMaxNReg(const Function &, unsigned &);
bool isKernelFunction(const Function &);

bool getAlign(const Function &, unsigned index, unsigned &);
bool getAlign(const CallInst &, unsigned index, unsigned &);
MaybeAlign getAlign(const Function &, unsigned);
MaybeAlign getAlign(const CallInst &, unsigned);
Function *getMaybeBitcastedCallee(const CallBase *CB);

// PTX ABI requires all scalar argument/return values to have
Expand Down
109 changes: 109 additions & 0 deletions llvm/test/CodeGen/NVPTX/param-overalign.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
; RUN: llc < %s -march=nvptx | FileCheck %s
; RUN: %if ptxas %{ llc < %s -march=nvptx -verify-machineinstrs | %ptxas-verify %}

target triple = "nvptx64-nvidia-cuda"

%struct.float2 = type { float, float }

; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md
; CHECK-NEXT: (
; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8]
; CHECK-NEXT: )
; CHECK-NEXT: ;

; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee
; CHECK-NEXT: (
; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8]
; CHECK-NEXT: )
; CHECK-NEXT: ;

; Packs %a and %b into a %struct.float2 and passes it by value to @callee_md.
; @callee_md's parameter has no alignstack attribute; the only alignment hint
; for it is the legacy !nvvm.annotations "align" entry at the end of this file.
; The expectations below therefore cover the metadata path on the call side:
; param0 is declared with .align 8 and filled with a single v2 store.
define float @caller_md(float %a, float %b) {
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller_md(
; CHECK-NEXT: .param .b32 caller_md_param_0,
; CHECK-NEXT: .param .b32 caller_md_param_1
; CHECK-NEXT: )
; CHECK-NEXT: {

; CHECK: ld.param.f32 %f1, [caller_md_param_0];
; CHECK-NEXT: ld.param.f32 %f2, [caller_md_param_1];
; CHECK-NEXT: {
; CHECK-NEXT: .param .align 8 .b8 param0[8];
; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2};
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: call.uni (retval0),
; CHECK-NEXT: callee_md,
; CHECK-NEXT: (
; CHECK-NEXT: param0
; CHECK-NEXT: );
; CHECK-NEXT: ld.param.f32 %f3, [retval0+0];
; CHECK-NEXT: }
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
; CHECK-NEXT: ret;
%s1 = insertvalue %struct.float2 poison, float %a, 0
%s2 = insertvalue %struct.float2 %s1, float %b, 1
%r = call float @callee_md(%struct.float2 %s2)
ret float %r
}

; Receives a %struct.float2 by value and returns the sum of its two fields.
; The parameter carries no alignstack attribute; the 8-byte alignment expected
; below comes from the !nvvm.annotations "align" entry for @callee_md
; (0x00010008: the upper 16 bits select index 1, the lower 16 bits give 8).
; With that alignment both fields are expected to be read by one v2 load.
define float @callee_md(%struct.float2 %a) {
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md(
; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8]
; CHECK-NEXT: )
; CHECK-NEXT: {

; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_md_param_0];
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
; CHECK-NEXT: ret;
%v0 = extractvalue %struct.float2 %a, 0
%v1 = extractvalue %struct.float2 %a, 1
%2 = fadd float %v0, %v1
ret float %2
}

; Same shape as @caller_md, but the callee is @callee, whose parameter is
; declared with alignstack(8) instead of relying on nvvm annotations.
; The expectations below cover the attribute path on the call side:
; param0 is declared with .align 8 and filled with a single v2 store.
define float @caller(float %a, float %b) {
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller(
; CHECK-NEXT: .param .b32 caller_param_0,
; CHECK-NEXT: .param .b32 caller_param_1
; CHECK-NEXT: )
; CHECK-NEXT: {

; CHECK: ld.param.f32 %f1, [caller_param_0];
; CHECK-NEXT: ld.param.f32 %f2, [caller_param_1];
; CHECK-NEXT: {
; CHECK-NEXT: .param .align 8 .b8 param0[8];
; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2};
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: call.uni (retval0),
; CHECK-NEXT: callee,
; CHECK-NEXT: (
; CHECK-NEXT: param0
; CHECK-NEXT: );
; CHECK-NEXT: ld.param.f32 %f3, [retval0+0];
; CHECK-NEXT: }
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
; CHECK-NEXT: ret;
%s1 = insertvalue %struct.float2 poison, float %a, 0
%s2 = insertvalue %struct.float2 %s1, float %b, 1
%r = call float @callee(%struct.float2 %s2)
ret float %r
}

; Receives a %struct.float2 by value and returns the sum of its two fields.
; The parameter is over-aligned through the alignstack(8) attribute (no nvvm
; annotation exists for @callee), so the expectations below require the
; parameter to be declared with .align 8 and read with a single v2 load.
define float @callee(%struct.float2 alignstack(8) %a ) {
; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee(
; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8]
; CHECK-NEXT: )
; CHECK-NEXT: {

; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_param_0];
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
; CHECK-NEXT: ret;
%v0 = extractvalue %struct.float2 %a, 0
%v1 = extractvalue %struct.float2 %a, 1
%2 = fadd float %v0, %v1
ret float %2
}

!nvvm.annotations = !{!0}
!0 = !{ptr @callee_md, !"align", i32 u0x00010008}
Loading