Skip to content

Commit 28cdab8

Browse files
Add support for adding 'nvvm.grid_constant' attribute to LLVM function arguments
1 parent 5b4f2b9 commit 28cdab8

File tree

9 files changed

+214
-25
lines changed

9 files changed

+214
-25
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
5959
/// Get the name of the attribute used to annotate max number of
6060
/// registers that can be allocated per thread.
6161
static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
62+
63+
/// Get the name of the attribute used to annotate kernel arguments that
64+
/// are grid constants.
65+
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
66+
67+
/// Verify an attribute from this dialect on the argument at 'argIndex' for
68+
/// the region at 'regionIndex' on the given operation. Returns failure if
69+
/// the verification failed, success otherwise. This hook may optionally be
70+
/// invoked from any operation containing a region.
71+
LogicalResult verifyRegionArgAttribute(Operation *,
72+
unsigned regionIndex,
73+
unsigned argIndex,
74+
NamedAttribute) override;
6275
}];
6376

6477
let useDefaultAttributePrinterParser = 1;

mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
1414
#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
1515

16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
1718
#include "mlir/IR/DialectInterface.h"
1819
#include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
2526
namespace mlir {
2627
namespace LLVM {
2728
class ModuleTranslation;
29+
class LLVMFuncOp;
2830
} // namespace LLVM
2931

3032
/// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
5860
LLVM::ModuleTranslation &moduleTranslation) const {
5961
return success();
6062
}
63+
64+
/// Hook for derived dialect interface to translate or act on a derived
65+
/// dialect attribute that appears on a function parameter. This gets called
66+
/// after the function operation has been translated.
67+
virtual LogicalResult
68+
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
69+
NamedAttribute attr,
70+
LLVM::ModuleTranslation &moduleTranslation) const {
71+
return success();
72+
}
6173
};
6274

6375
/// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
90102
}
91103
return success();
92104
}
105+
106+
/// Acts on the given function operation using the interface implemented by
107+
/// the dialect of one of the function parameter attributes.
108+
virtual LogicalResult
109+
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
110+
NamedAttribute attribute,
111+
LLVM::ModuleTranslation &moduleTranslation) const {
112+
if (const LLVMTranslationDialectInterface *iface =
113+
getInterfaceFor(attribute.getNameDialect())) {
114+
return iface->convertParameterAttr(function, argIdx, attribute,
115+
moduleTranslation);
116+
}
117+
return success();
118+
}
93119
};
94120

95121
} // namespace mlir

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ class ModuleTranslation {
327327
ArrayRef<llvm::Instruction *> instructions);
328328

329329
/// Translates parameter attributes and adds them to the returned AttrBuilder.
330-
llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
330+
FailureOr<llvm::AttrBuilder>
331+
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
331332

332333
/// Original and translated module.
333334
Operation *mlirModule;

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
10771077
return success();
10781078
}
10791079

1080+
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1081+
unsigned regionIndex,
1082+
unsigned argIndex,
1083+
NamedAttribute argAttr) {
1084+
auto funcOp = dyn_cast<FunctionOpInterface>(op);
1085+
if (!funcOp)
1086+
return success();
1087+
1088+
bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1089+
auto attrName = argAttr.getName();
1090+
if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1091+
if (!isKernel)
1092+
return op->emitError()
1093+
<< "'" << attrName
1094+
<< "' attribute must be present only on kernel arguments.";
1095+
if (!llvm::isa<UnitAttr>(argAttr.getValue()))
1096+
return op->emitError()
1097+
<< "'" << attrName << "' must be a unit attribute.";
1098+
if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
1099+
return op->emitError()
1100+
<< "'" << attrName
1101+
<< "' attribute requires the argument to also have attribute '"
1102+
<< LLVM::LLVMDialect::getByValAttrName() << "'.";
1103+
}
1104+
1105+
return success();
1106+
}
1107+
10801108
//===----------------------------------------------------------------------===//
10811109
// NVVM target attribute.
10821110
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/AttrKindDetail.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
5959
return kindNamePairs;
6060
}
6161

62+
static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
63+
getAttrNameToKindMapping() {
64+
static auto attrNameToKindMapping = []() {
65+
static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
66+
nameKindMap;
67+
for (auto kindNamePair : getAttrKindToNameMapping()) {
68+
nameKindMap.insert({kindNamePair.second, kindNamePair.first});
69+
}
70+
return nameKindMap;
71+
}();
72+
return attrNameToKindMapping;
73+
}
74+
6275
} // namespace detail
6376
} // namespace LLVM
6477
} // namespace mlir

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
201201
}
202202
return success();
203203
}
204+
205+
LogicalResult
206+
convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
207+
LLVM::ModuleTranslation &moduleTranslation) const final {
208+
209+
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
210+
llvm::Function *llvmFunc =
211+
moduleTranslation.lookupFunction(funcOp.getName());
212+
auto nvvmAnnotations =
213+
moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
214+
215+
if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
216+
llvm::MDNode *gridConstantMetaData = nullptr;
217+
218+
// Check if a 'grid_constant' metadata node exists for the given function
219+
for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
220+
auto opnd = nvvmAnnotations->getOperand(i);
221+
if (opnd->getNumOperands() == 3 &&
222+
opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
223+
opnd->getOperand(1) ==
224+
llvm::MDString::get(llvmContext, "grid_constant")) {
225+
gridConstantMetaData = opnd;
226+
break;
227+
}
228+
}
229+
230+
// 'grid_constant' is a function-level meta data node with a list of
231+
// integers, where each integer n denotes that the nth parameter has the
232+
// grid_constant annotation (numbering from 1). This requires aggregating
233+
// the indices of the individual parameters that have this attribute.
234+
llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
235+
if (gridConstantMetaData == nullptr) {
236+
// Create a new 'grid_constant' metadata node
237+
SmallVector<llvm::Metadata *> gridConstMetadata = {
238+
llvm::ValueAsMetadata::getConstant(
239+
llvm::ConstantInt::get(i32, argIdx + 1))};
240+
llvm::Metadata *llvmMetadata[] = {
241+
llvm::ValueAsMetadata::get(llvmFunc),
242+
llvm::MDString::get(llvmContext, "grid_constant"),
243+
llvm::MDNode::get(llvmContext, gridConstMetadata)};
244+
llvm::MDNode *llvmMetadataNode =
245+
llvm::MDNode::get(llvmContext, llvmMetadata);
246+
nvvmAnnotations->addOperand(llvmMetadataNode);
247+
} else {
248+
// Append argIdx + 1 to the 'grid_constant' argument list
249+
if (auto argList =
250+
dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
251+
auto clonedArgList = argList->clone();
252+
clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
253+
llvm::ConstantInt::get(i32, argIdx + 1))));
254+
gridConstantMetaData->replaceOperandWith(
255+
2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
256+
}
257+
}
258+
}
259+
return success();
260+
}
204261
};
205262
} // namespace
206263

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
11741174
llvmFunc->setMemoryEffects(newMemEffects);
11751175
}
11761176

1177-
llvm::AttrBuilder
1178-
ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
1177+
FailureOr<llvm::AttrBuilder>
1178+
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
1179+
DictionaryAttr paramAttrs) {
11791180
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1180-
1181-
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
1182-
Attribute attr = paramAttrs.get(mlirName);
1183-
// Skip attributes that are not present.
1184-
if (!attr)
1185-
continue;
1186-
1187-
// NOTE: C++17 does not support capturing structured bindings.
1188-
llvm::Attribute::AttrKind llvmKindCap = llvmKind;
1189-
1190-
llvm::TypeSwitch<Attribute>(attr)
1191-
.Case<TypeAttr>([&](auto typeAttr) {
1192-
attrBuilder.addTypeAttr(llvmKindCap,
1193-
convertType(typeAttr.getValue()));
1194-
})
1195-
.Case<IntegerAttr>([&](auto intAttr) {
1196-
attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
1197-
})
1198-
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
1181+
auto attrNameToKindMapping = getAttrNameToKindMapping();
1182+
1183+
for (auto namedAttr : paramAttrs) {
1184+
auto it = attrNameToKindMapping.find(namedAttr.getName());
1185+
if (it != attrNameToKindMapping.end()) {
1186+
llvm::Attribute::AttrKind llvmKind = it->second;
1187+
1188+
llvm::TypeSwitch<Attribute>(namedAttr.getValue())
1189+
.Case<TypeAttr>([&](auto typeAttr) {
1190+
attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
1191+
})
1192+
.Case<IntegerAttr>([&](auto intAttr) {
1193+
attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1194+
})
1195+
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
1196+
} else if (namedAttr.getNameDialect()) {
1197+
if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
1198+
return failure();
1199+
}
11991200
}
12001201

12011202
return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
12241225
// Convert result attributes.
12251226
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
12261227
DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1227-
llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
1228+
FailureOr<llvm::AttrBuilder> attrBuilder =
1229+
convertParameterAttrs(function, -1, resultAttrs);
1230+
if (failed(attrBuilder))
1231+
return failure();
1232+
llvmFunc->addRetAttrs(*attrBuilder);
12281233
}
12291234

12301235
// Convert argument attributes.
12311236
for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
12321237
if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1233-
llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
1234-
llvmArg.addAttrs(attrBuilder);
1238+
FailureOr<llvm::AttrBuilder> attrBuilder =
1239+
convertParameterAttrs(function, argIdx, argAttrs);
1240+
if (failed(attrBuilder))
1241+
return failure();
1242+
llvmArg.addAttrs(*attrBuilder);
12351243
}
12361244
}
12371245

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
472472

473473
gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
474474
}
475+
476+
// CHECK-LABEL : nvvm.grid_constant
477+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
478+
llvm.return
479+
}
480+
481+
// -----
482+
483+
// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
484+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
485+
llvm.return
486+
}
487+
488+
// -----
489+
490+
// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
491+
llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
492+
llvm.return
493+
}
494+
495+
// -----
496+
497+
// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
498+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
499+
llvm.return
500+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
518518
llvm.return
519519
}
520520

521+
// -----
522+
// CHECK: !nvvm.annotations =
523+
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
524+
// CHECK: !2 = !{i32 1}
525+
// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
526+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
527+
llvm.return
528+
}
529+
530+
// -----
531+
// CHECK: !nvvm.annotations =
532+
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
533+
// CHECK: !2 = !{i32 1, i32 3}
534+
// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
535+
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
536+
llvm.return
537+
}

0 commit comments

Comments
 (0)