Skip to content

Commit fa6850a

Browse files
[mlir][nvvm]Add support for grid_constant attribute on LLVM function arguments (#78228)
Add support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute. Generate LLVM metadata for functions with nvvm.grid_constant arguments. The metadata node is a list of integers, where each integer n denotes that the nth parameter has the grid_constant annotation (numbering from 1). The generated metadata node will be handled by NVVM compiler. See https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for documentation on grid_constant property. This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters 
1 parent 8799d71 commit fa6850a

File tree

10 files changed

+223
-30
lines changed

10 files changed

+223
-30
lines changed

llvm/include/llvm/IR/Metadata.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,7 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
17351735

17361736
explicit NamedMDNode(const Twine &N);
17371737

1738-
template <class T1, class T2> class op_iterator_impl {
1738+
template <class T1> class op_iterator_impl {
17391739
friend class NamedMDNode;
17401740

17411741
const NamedMDNode *Node = nullptr;
@@ -1745,10 +1745,10 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
17451745

17461746
public:
17471747
using iterator_category = std::bidirectional_iterator_tag;
1748-
using value_type = T2;
1748+
using value_type = T1;
17491749
using difference_type = std::ptrdiff_t;
17501750
using pointer = value_type *;
1751-
using reference = value_type &;
1751+
using reference = value_type;
17521752

17531753
op_iterator_impl() = default;
17541754

@@ -1809,12 +1809,12 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
18091809
// ---------------------------------------------------------------------------
18101810
// Operand Iterator interface...
18111811
//
1812-
using op_iterator = op_iterator_impl<MDNode *, MDNode>;
1812+
using op_iterator = op_iterator_impl<MDNode *>;
18131813

18141814
op_iterator op_begin() { return op_iterator(this, 0); }
18151815
op_iterator op_end() { return op_iterator(this, getNumOperands()); }
18161816

1817-
using const_op_iterator = op_iterator_impl<const MDNode *, MDNode>;
1817+
using const_op_iterator = op_iterator_impl<const MDNode *>;
18181818

18191819
const_op_iterator op_begin() const { return const_op_iterator(this, 0); }
18201820
const_op_iterator op_end() const { return const_op_iterator(this, getNumOperands()); }

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
5959
/// Get the name of the attribute used to annotate max number of
6060
/// registers that can be allocated per thread.
6161
static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
62+
63+
/// Get the name of the attribute used to annotate kernel arguments that
64+
/// are grid constants.
65+
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
66+
67+
/// Verify an attribute from this dialect on the argument at 'argIndex' for
68+
/// the region at 'regionIndex' on the given operation. Returns failure if
69+
/// the verification failed, success otherwise. This hook may optionally be
70+
/// invoked from any operation containing a region.
71+
LogicalResult verifyRegionArgAttribute(Operation *op,
72+
unsigned regionIndex,
73+
unsigned argIndex,
74+
NamedAttribute argAttr) override;
6275
}];
6376

6477
let useDefaultAttributePrinterParser = 1;

mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
1414
#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
1515

16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
1718
#include "mlir/IR/DialectInterface.h"
1819
#include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
2526
namespace mlir {
2627
namespace LLVM {
2728
class ModuleTranslation;
29+
class LLVMFuncOp;
2830
} // namespace LLVM
2931

3032
/// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
5860
LLVM::ModuleTranslation &moduleTranslation) const {
5961
return success();
6062
}
63+
64+
/// Hook for derived dialect interface to translate or act on a derived
65+
/// dialect attribute that appears on a function parameter. This gets called
66+
/// after the function operation has been translated.
67+
virtual LogicalResult
68+
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
69+
NamedAttribute attr,
70+
LLVM::ModuleTranslation &moduleTranslation) const {
71+
return success();
72+
}
6173
};
6274

6375
/// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,22 @@ class LLVMTranslationInterface
90102
}
91103
return success();
92104
}
105+
106+
/// Acts on the given function operation using the interface implemented by
107+
/// the dialect of one of the function parameter attributes.
108+
virtual LogicalResult
109+
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
110+
NamedAttribute attribute,
111+
LLVM::ModuleTranslation &moduleTranslation) const {
112+
if (const LLVMTranslationDialectInterface *iface =
113+
getInterfaceFor(attribute.getNameDialect())) {
114+
return iface->convertParameterAttr(function, argIdx, attribute,
115+
moduleTranslation);
116+
}
117+
function.emitWarning("Unhandled parameter attribute '" +
118+
attribute.getName().str() + "'");
119+
return success();
120+
}
93121
};
94122

95123
} // namespace mlir

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ class ModuleTranslation {
327327
ArrayRef<llvm::Instruction *> instructions);
328328

329329
/// Translates parameter attributes and adds them to the returned AttrBuilder.
330-
llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
330+
/// Returns failure if any of the translations failed.
331+
FailureOr<llvm::AttrBuilder>
332+
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
331333

332334
/// Original and translated module.
333335
Operation *mlirModule;

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,35 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
10741074
return success();
10751075
}
10761076

1077+
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1078+
unsigned regionIndex,
1079+
unsigned argIndex,
1080+
NamedAttribute argAttr) {
1081+
auto funcOp = dyn_cast<FunctionOpInterface>(op);
1082+
if (!funcOp)
1083+
return success();
1084+
1085+
bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1086+
StringAttr attrName = argAttr.getName();
1087+
if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1088+
if (!isKernel) {
1089+
return op->emitError()
1090+
<< "'" << attrName
1091+
<< "' attribute must be present only on kernel arguments";
1092+
}
1093+
if (!isa<UnitAttr>(argAttr.getValue()))
1094+
return op->emitError() << "'" << attrName << "' must be a unit attribute";
1095+
if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1096+
return op->emitError()
1097+
<< "'" << attrName
1098+
<< "' attribute requires the argument to also have attribute '"
1099+
<< LLVM::LLVMDialect::getByValAttrName() << "'";
1100+
}
1101+
}
1102+
1103+
return success();
1104+
}
1105+
10771106
//===----------------------------------------------------------------------===//
10781107
// NVVM target attribute.
10791108
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/AttrKindDetail.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,20 @@ getAttrKindToNameMapping() {
5959
return kindNamePairs;
6060
}
6161

62+
/// Returns a dense map from LLVM attribute name to their kind in LLVM IR
63+
/// dialect.
64+
static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
65+
getAttrNameToKindMapping() {
66+
static auto attrNameToKindMapping = []() {
67+
llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind> nameKindMap;
68+
for (auto kindNamePair : getAttrKindToNameMapping()) {
69+
nameKindMap.insert({kindNamePair.second, kindNamePair.first});
70+
}
71+
return nameKindMap;
72+
}();
73+
return attrNameToKindMapping;
74+
}
75+
6276
} // namespace detail
6377
} // namespace LLVM
6478
} // namespace mlir

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,62 @@ class NVVMDialectLLVMIRTranslationInterface
201201
}
202202
return success();
203203
}
204+
205+
LogicalResult
206+
convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
207+
LLVM::ModuleTranslation &moduleTranslation) const final {
208+
209+
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
210+
llvm::Function *llvmFunc =
211+
moduleTranslation.lookupFunction(funcOp.getName());
212+
llvm::NamedMDNode *nvvmAnnotations =
213+
moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
214+
215+
if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
216+
llvm::MDNode *gridConstantMetaData = nullptr;
217+
218+
// Check if a 'grid_constant' metadata node exists for the given function
219+
for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
220+
if (opnd->getNumOperands() == 3 &&
221+
opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
222+
opnd->getOperand(1) ==
223+
llvm::MDString::get(llvmContext, "grid_constant")) {
224+
gridConstantMetaData = opnd;
225+
break;
226+
}
227+
}
228+
229+
// 'grid_constant' is a function-level meta data node with a list of
230+
// integers, where each integer n denotes that the nth parameter has the
231+
// grid_constant annotation (numbering from 1). This requires aggregating
232+
// the indices of the individual parameters that have this attribute.
233+
llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
234+
if (gridConstantMetaData == nullptr) {
235+
// Create a new 'grid_constant' metadata node
236+
SmallVector<llvm::Metadata *> gridConstMetadata = {
237+
llvm::ValueAsMetadata::getConstant(
238+
llvm::ConstantInt::get(i32, argIdx + 1))};
239+
llvm::Metadata *llvmMetadata[] = {
240+
llvm::ValueAsMetadata::get(llvmFunc),
241+
llvm::MDString::get(llvmContext, "grid_constant"),
242+
llvm::MDNode::get(llvmContext, gridConstMetadata)};
243+
llvm::MDNode *llvmMetadataNode =
244+
llvm::MDNode::get(llvmContext, llvmMetadata);
245+
nvvmAnnotations->addOperand(llvmMetadataNode);
246+
} else {
247+
// Append argIdx + 1 to the 'grid_constant' argument list
248+
if (auto argList =
249+
dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
250+
llvm::TempMDTuple clonedArgList = argList->clone();
251+
clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
252+
llvm::ConstantInt::get(i32, argIdx + 1))));
253+
gridConstantMetaData->replaceOperandWith(
254+
2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
255+
}
256+
}
257+
}
258+
return success();
259+
}
204260
};
205261
} // namespace
206262

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,28 +1298,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
12981298
llvmFunc->setMemoryEffects(newMemEffects);
12991299
}
13001300

1301-
llvm::AttrBuilder
1302-
ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
1301+
FailureOr<llvm::AttrBuilder>
1302+
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
1303+
DictionaryAttr paramAttrs) {
13031304
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1304-
1305-
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
1306-
Attribute attr = paramAttrs.get(mlirName);
1307-
// Skip attributes that are not present.
1308-
if (!attr)
1309-
continue;
1310-
1311-
// NOTE: C++17 does not support capturing structured bindings.
1312-
llvm::Attribute::AttrKind llvmKindCap = llvmKind;
1313-
1314-
llvm::TypeSwitch<Attribute>(attr)
1315-
.Case<TypeAttr>([&](auto typeAttr) {
1316-
attrBuilder.addTypeAttr(llvmKindCap,
1317-
convertType(typeAttr.getValue()));
1318-
})
1319-
.Case<IntegerAttr>([&](auto intAttr) {
1320-
attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
1321-
})
1322-
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
1305+
auto attrNameToKindMapping = getAttrNameToKindMapping();
1306+
1307+
for (auto namedAttr : paramAttrs) {
1308+
auto it = attrNameToKindMapping.find(namedAttr.getName());
1309+
if (it != attrNameToKindMapping.end()) {
1310+
llvm::Attribute::AttrKind llvmKind = it->second;
1311+
1312+
llvm::TypeSwitch<Attribute>(namedAttr.getValue())
1313+
.Case<TypeAttr>([&](auto typeAttr) {
1314+
attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
1315+
})
1316+
.Case<IntegerAttr>([&](auto intAttr) {
1317+
attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1318+
})
1319+
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
1320+
} else if (namedAttr.getNameDialect()) {
1321+
if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
1322+
return failure();
1323+
}
13231324
}
13241325

13251326
return attrBuilder;
@@ -1348,14 +1349,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
13481349
// Convert result attributes.
13491350
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
13501351
DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1351-
llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
1352+
FailureOr<llvm::AttrBuilder> attrBuilder =
1353+
convertParameterAttrs(function, -1, resultAttrs);
1354+
if (failed(attrBuilder))
1355+
return failure();
1356+
llvmFunc->addRetAttrs(*attrBuilder);
13521357
}
13531358

13541359
// Convert argument attributes.
13551360
for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
13561361
if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1357-
llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
1358-
llvmArg.addAttrs(attrBuilder);
1362+
FailureOr<llvm::AttrBuilder> attrBuilder =
1363+
convertParameterAttrs(function, argIdx, argAttrs);
1364+
if (failed(attrBuilder))
1365+
return failure();
1366+
llvmArg.addAttrs(*attrBuilder);
13591367
}
13601368
}
13611369

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
472472

473473
gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
474474
}
475+
476+
// CHECK-LABEL : nvvm.grid_constant
477+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
478+
llvm.return
479+
}
480+
481+
// -----
482+
483+
// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
484+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
485+
llvm.return
486+
}
487+
488+
// -----
489+
490+
// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
491+
llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
492+
llvm.return
493+
}
494+
495+
// -----
496+
497+
// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
498+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
499+
llvm.return
500+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
538538
llvm.return
539539
}
540540

541+
// -----
542+
// CHECK: !nvvm.annotations =
543+
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
544+
// CHECK: !2 = !{i32 1}
545+
// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
546+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
547+
llvm.return
548+
}
549+
550+
// -----
551+
// CHECK: !nvvm.annotations =
552+
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
553+
// CHECK: !2 = !{i32 1, i32 3}
554+
// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
555+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
556+
llvm.return
557+
}

0 commit comments

Comments
 (0)