Skip to content

[NVPTX] Auto-Upgrade some nvvm.annotations to attributes #119261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/AutoUpgrade.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ namespace llvm {
/// module is modified.
bool UpgradeModuleFlags(Module &M);

/// Convert legacy nvvm.annotations metadata to appropriate function
/// attributes.
void UpgradeNVVMAnnotations(Module &M);

/// Convert calls to ARC runtime functions to intrinsic calls and upgrade the
/// old retain release marker to new module flag format.
void UpgradeARCRuntime(Module &M);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/AsmParser/LLParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ bool LLParser::validateEndOfModule(bool UpgradeDebugInfo) {
llvm::UpgradeDebugInfo(*M);

UpgradeModuleFlags(*M);
UpgradeNVVMAnnotations(*M);
UpgradeSectionAttributes(*M);

if (PreserveInputDbgFormat != cl::boolOrDefault::BOU_TRUE)
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7157,6 +7157,8 @@ Error BitcodeReader::materializeModule() {

UpgradeModuleFlags(*TheModule);

UpgradeNVVMAnnotations(*TheModule);

UpgradeARCRuntime(*TheModule);

return Error::success();
Expand Down
67 changes: 67 additions & 0 deletions llvm/lib/IR/AutoUpgrade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
Expand Down Expand Up @@ -5019,6 +5020,72 @@ bool llvm::UpgradeDebugInfo(Module &M) {
return Modified;
}

/// Try to convert a single legacy nvvm.annotations key/value pair attached to
/// \p GV into its IR-level equivalent.
///
/// \param GV The global the annotation is attached to. For the keys handled
///           here it is cast to a Function.
/// \param K  The annotation key.
/// \param V  The annotation value. For the keys handled here it is read as a
///           ConstantInt.
/// \returns true if the pair was consumed and should be dropped from the
///          metadata node, false if it must be kept unchanged.
static bool upgradeSingleNVVMAnnotation(GlobalValue *GV, StringRef K,
                                        const Metadata *V) {
  if (K == "kernel") {
    // The pair is consumed regardless of its value; only a non-zero value
    // changes the calling convention.
    if (!mdconst::extract<ConstantInt>(V)->isZero())
      cast<Function>(GV)->setCallingConv(CallingConv::PTX_Kernel);
    return true;
  }
  if (K == "align") {
    // V is a bitfield specifying two 16-bit values. The alignment value is
    // specified in the low 16 bits, the index is specified in the high bits.
    // For the index, 0 indicates the return value while higher values
    // correspond to each parameter (idx = param + 1).
    const uint64_t AlignIdxValuePair =
        mdconst::extract<ConstantInt>(V)->getZExtValue();
    const unsigned Idx = (AlignIdxValuePair >> 16);
    // TODO: Skip adding the stackalign attribute for returns, for now.
    if (!Idx)
      return false;
    // Build the Align only once we know it will be used, so that entries we
    // leave untouched are never validated by the Align constructor.
    const Align StackAlign = Align(AlignIdxValuePair & 0xFFFF);
    cast<Function>(GV)->addAttributeAtIndex(
        Idx, Attribute::getWithStackAlignment(GV->getContext(), StackAlign));
    return true;
  }

  // Unknown key: leave the annotation in place.
  return false;
}

/// Rewrite the "nvvm.annotations" named metadata of \p M.
///
/// Every distinct annotation node whose first operand is a GlobalValue is
/// walked pair by pair; pairs that upgradeSingleNVVMAnnotation() consumes are
/// removed. Nodes that still hold at least one pair afterwards are re-created
/// and become the new operand list. Nodes whose first operand is not a
/// GlobalValue (including null) and duplicate references to the same node are
/// dropped.
void llvm::UpgradeNVVMAnnotations(Module &M) {
  NamedMDNode *Annotations = M.getNamedMetadata("nvvm.annotations");
  if (!Annotations)
    return;

  SmallSet<const MDNode *, 8> Visited;
  SmallVector<MDNode *, 8> Remaining;
  for (MDNode *Node : Annotations->operands()) {
    // The same node may be referenced more than once; handle it only the
    // first time it is encountered.
    const bool FirstVisit = Visited.insert(Node).second;
    if (!FirstVisit)
      continue;

    auto *GV = mdconst::dyn_extract_or_null<GlobalValue>(Node->getOperand(0));
    if (!GV)
      continue;

    const unsigned NumOps = Node->getNumOperands();
    assert((NumOps % 2) == 1 && "Invalid number of operands");

    // Layout of one nvvm.annotations entry:
    //   !{ ptr @gv, !"key1", value1, !"key2", value2, ... }
    // Operand 0 is the annotated global; keys sit at the odd positions and
    // each key is immediately followed by its value.
    SmallVector<Metadata *, 8> Kept{Node->getOperand(0)};
    for (unsigned I = 1; I < NumOps; I += 2) {
      MDString *Key = cast<MDString>(Node->getOperand(I));
      const MDOperand &Val = Node->getOperand(I + 1);
      if (upgradeSingleNVVMAnnotation(GV, Key->getString(), Val))
        continue;
      Kept.append({Key, Val});
    }

    // A node left with only the global operand has nothing to say anymore.
    if (Kept.size() > 1)
      Remaining.push_back(MDNode::get(M.getContext(), Kept));
  }

  Annotations->clearOperands();
  for (MDNode *Node : Remaining)
    Annotations->addOperand(Node);
}

/// This checks for objc retain release marker which should be upgraded. It
/// returns true if module is modified.
static bool upgradeRetainReleaseMarker(Module &M) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Linker/IRMover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,7 @@ Error IRLinker::linkModuleFlagsMetadata() {

// Check for module flag for updates before do anything.
UpgradeModuleFlags(*SrcM);
UpgradeNVVMAnnotations(*SrcM);

// If the destination module doesn't have module flags yet, then just copy
// over the source module's flags.
Expand Down
27 changes: 9 additions & 18 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,30 +310,21 @@ std::optional<unsigned> getMaxNReg(const Function &F) {
return findOneNVVMAnnotation(&F, "maxnreg");
}

/// Returns true if \p F is a kernel: either it uses the PTX kernel calling
/// convention, or it has a "kernel" nvvm annotation whose value equals 1.
bool isKernelFunction(const Function &F) {
  const bool HasKernelCC = F.getCallingConv() == CallingConv::PTX_Kernel;
  if (HasKernelCC)
    return true;

  // Fall back to the annotation lookup when the calling convention does not
  // already identify the function as a kernel.
  const auto KernelAnnotation = findOneNVVMAnnotation(&F, "kernel");
  if (!KernelAnnotation)
    return false;
  return *KernelAnnotation == 1;
}

MaybeAlign getAlign(const Function &F, unsigned Index) {
// First check the alignstack metadata
if (MaybeAlign StackAlign =
F.getAttributes().getAttributes(Index).getStackAlignment())
return StackAlign;

// If that is missing, check the legacy nvvm metadata
std::vector<unsigned> Vs;
bool retval = findAllNVVMAnnotation(&F, "align", Vs);
if (!retval)
return std::nullopt;
for (unsigned V : Vs)
if ((V >> 16) == Index)
return Align(V & 0xFFFF);
// check the legacy nvvm metadata only for the return value since llvm does
// not support stackalign attribute for this.
if (Index == 0) {
std::vector<unsigned> Vs;
if (findAllNVVMAnnotation(&F, "align", Vs))
Copy link
Member

@Artem-B Artem-B Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Off-topic: I think that if findAllNVVMAnnotation() returned an ArrayRef, it would work much more nicely than copying the data into a temporary array. Bonus points for making the name plural.

if (Index == 0) {
  for (unsigned V : findAllNVVMAnnotation())
     do stuff;
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree the NVVM annotation APIs could be cleaned up significantly; hopefully, though, this work will remove the need for them altogether.

for (unsigned V : Vs)
if ((V >> 16) == Index)
return Align(V & 0xFFFF);
}

return std::nullopt;
}
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "NVPTX.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IntrinsicInst.h"
Expand Down Expand Up @@ -63,7 +64,11 @@ std::optional<unsigned> getClusterDimz(const Function &);
std::optional<unsigned> getMaxClusterRank(const Function &);
std::optional<unsigned> getMinCTASm(const Function &);
std::optional<unsigned> getMaxNReg(const Function &);
bool isKernelFunction(const Function &);

inline bool isKernelFunction(const Function &F) {
return F.getCallingConv() == CallingConv::PTX_Kernel;
}

bool isParamGridConstant(const Value &);

MaybeAlign getAlign(const Function &, unsigned);
Expand Down
28 changes: 4 additions & 24 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5906,39 +5906,19 @@ bool llvm::omp::isOpenMPKernel(Function &Fn) {
}

KernelSet llvm::omp::getDeviceKernels(Module &M) {
// TODO: Create a more cross-platform way of determining device kernels.
KernelSet Kernels;

DenseSet<const Function *> SeenKernels;
auto ProcessKernel = [&](Function &KF) {
if (SeenKernels.insert(&KF).second) {
for (Function &F : M)
if (F.hasKernelCallingConv()) {
// We are only interested in OpenMP target regions. Others, such as
// kernels generated by CUDA but linked together, are not interesting to
// this pass.
if (isOpenMPKernel(KF)) {
if (isOpenMPKernel(F)) {
++NumOpenMPTargetRegionKernels;
Kernels.insert(&KF);
Kernels.insert(&F);
} else
++NumNonOpenMPTargetRegionKernels;
}
};

if (NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations"))
for (auto *Op : MD->operands()) {
if (Op->getNumOperands() < 2)
continue;
MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
if (!KindID || KindID->getString() != "kernel")
continue;

if (auto *KernelFn =
mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)))
ProcessKernel(*KernelFn);
}

for (Function &F : M)
if (F.hasKernelCallingConv())
ProcessKernel(F);

return Kernels;
}
Expand Down
28 changes: 28 additions & 0 deletions llvm/test/CodeGen/NVPTX/upgrade-nvvm-annotations.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-attributes --check-globals all --version 5
; RUN: opt < %s -mtriple=nvptx64-unknown-unknown -O0 -S | FileCheck %s

; @foo is referenced by !0, which holds three "align" entries. The entries for
; indices 1 and 2 are expected to show up as alignstack attributes on %a and
; %b; the index-0 entry is expected to remain in !nvvm.annotations.
define i32 @foo(i32 %a, i32 %b) {
; CHECK-LABEL: define i32 @foo(
; CHECK-SAME: i32 alignstack(8) [[A:%.*]], i32 alignstack(16) [[B:%.*]]) {
; CHECK-NEXT: ret i32 0
;
ret i32 0
}

; @bar is referenced by !2, which holds a single "kernel" entry with value 1.
; The function is expected to be printed with the ptx_kernel calling
; convention and no parameter attributes.
define i32 @bar(i32 %a, i32 %b) {
; CHECK-LABEL: define ptx_kernel i32 @bar(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
; CHECK-NEXT: ret i32 0
;
ret i32 0
}

!nvvm.annotations = !{!0, !1, !2}

!0 = !{ptr @foo, !"align", i32 u0x00000008, !"align", i32 u0x00010008, !"align", i32 u0x00020010}
!1 = !{null, !"align", i32 u0x00000008, !"align", i32 u0x00010008, !"align", i32 u0x00020008}
!2 = !{ptr @bar, !"kernel", i32 1}

;.
; CHECK: [[META0:![0-9]+]] = !{ptr @foo, !"align", i32 8}
;.