Skip to content

[mlir][nvvm]Add support for grid_constant attribute on LLVM function arguments #78228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions llvm/include/llvm/IR/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,7 @@ class NamedMDNode : public ilist_node<NamedMDNode> {

explicit NamedMDNode(const Twine &N);

template <class T1, class T2> class op_iterator_impl {
template <class T1> class op_iterator_impl {
friend class NamedMDNode;

const NamedMDNode *Node = nullptr;
Expand All @@ -1711,10 +1711,10 @@ class NamedMDNode : public ilist_node<NamedMDNode> {

public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = T2;
using value_type = T1;
using difference_type = std::ptrdiff_t;
using pointer = value_type *;
using reference = value_type &;
using reference = value_type;

op_iterator_impl() = default;

Expand Down Expand Up @@ -1775,12 +1775,12 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
// ---------------------------------------------------------------------------
// Operand Iterator interface...
//
using op_iterator = op_iterator_impl<MDNode *, MDNode>;
using op_iterator = op_iterator_impl<MDNode *>;

op_iterator op_begin() { return op_iterator(this, 0); }
op_iterator op_end() { return op_iterator(this, getNumOperands()); }

using const_op_iterator = op_iterator_impl<const MDNode *, MDNode>;
using const_op_iterator = op_iterator_impl<const MDNode *>;

const_op_iterator op_begin() const { return const_op_iterator(this, 0); }
const_op_iterator op_end() const { return const_op_iterator(this, getNumOperands()); }
Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
/// Get the name of the attribute used to annotate max number of
/// registers that can be allocated per thread.
static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }

/// Get the name of the attribute used to annotate kernel arguments that
/// are grid constants.
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an attribute for a kernel parameter, while other attributes are for the kernel itself. We might need to split them later


/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the region at 'regionIndex' on the given operation. Returns failure if
/// the verification failed, success otherwise. This hook may optionally be
/// invoked from any operation containing a region.
LogicalResult verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute argAttr) override;
}];

let useDefaultAttributePrinterParser = 1;
Expand Down
28 changes: 28 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -25,6 +26,7 @@ class IRBuilderBase;
namespace mlir {
namespace LLVM {
class ModuleTranslation;
class LLVMFuncOp;
} // namespace LLVM

/// Base class for dialect interfaces providing translation to LLVM IR.
Expand Down Expand Up @@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}

/// Hook for derived dialect interface to translate or act on a derived
/// dialect attribute that appears on a function parameter. This gets called
/// after the function operation has been translated.
virtual LogicalResult
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of LLVM uses unsigned for indexes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you referring to?
I strongly believe in using int in absence of bit/mask operation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$ git grep " int " | wc -l
1161
$ git grep " unsigned " | wc -l
3449

also most of the former is in tests or tools.

That being said, I'm not a proponent of using unsigned (I'd rather use int32/64_t throughout) , but I am a proponent of consistency.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you suggest to fix this: I see it as a bug and I am consistently using signed int everywhere!
Should I sent patched updating every file for consistency or should we make sure new code use safer arithmetic patterns?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not seems like we'll be able to make anything a LLVM-wide policy here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR side we could. I'm rather pro not making it worse though (e.g., I'd consider consistency when it's towards direction we'd want to go rather than where we ended up in undesirable position and keeping going down that route)

NamedAttribute attr,
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}
};

/// Interface collection for translation to LLVM IR, dispatches to a concrete
Expand Down Expand Up @@ -90,6 +102,22 @@ class LLVMTranslationInterface
}
return success();
}

/// Acts on the given function operation using the interface implemented by
/// the dialect of one of the function parameter attributes.
virtual LogicalResult
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface =
getInterfaceFor(attribute.getNameDialect())) {
return iface->convertParameterAttr(function, argIdx, attribute,
moduleTranslation);
}
function.emitWarning("Unhandled parameter attribute '" +
attribute.getName().str() + "'");
return success();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the default here should be success. This would be dropping any attribute from a dialect without interface on the floor without warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will modify it to return failure. I followed what amendOperation was doing for op attributes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning failure causes several test failures. There are dialect attributes like 'fir.bindc_name' which doesn't require any handling here. I updated it to emit a warning.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unresolving, @ftynse should take another look here and acknowledge the solution explicitly)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the analogy is fair enough.

}
};

} // namespace mlir
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ class ModuleTranslation {
ArrayRef<llvm::Instruction *> instructions);

/// Translates parameter attributes and adds them to the returned AttrBuilder.
llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
/// Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder>
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);

/// Original and translated module.
Operation *mlirModule;
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,35 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute argAttr) {
auto funcOp = dyn_cast<FunctionOpInterface>(op);
if (!funcOp)
return success();

bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
StringAttr attrName = argAttr.getName();
if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
if (!isKernel) {
return op->emitError()
<< "'" << attrName
<< "' attribute must be present only on kernel arguments";
}
if (!isa<UnitAttr>(argAttr.getValue()))
return op->emitError() << "'" << attrName << "' must be a unit attribute";
if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
return op->emitError()
<< "'" << attrName
<< "' attribute requires the argument to also have attribute '"
<< LLVM::LLVMDialect::getByValAttrName() << "'";
}
}

return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Target/LLVMIR/AttrKindDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ getAttrKindToNameMapping() {
return kindNamePairs;
}

/// Returns a dense map from LLVM attribute name to their kind in LLVM IR
/// dialect.
static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
getAttrNameToKindMapping() {
static auto attrNameToKindMapping = []() {
llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind> nameKindMap;
for (auto kindNamePair : getAttrKindToNameMapping()) {
nameKindMap.insert({kindNamePair.second, kindNamePair.first});
}
return nameKindMap;
}();
return attrNameToKindMapping;
}

} // namespace detail
} // namespace LLVM
} // namespace mlir
Expand Down
56 changes: 56 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,62 @@ class NVVMDialectLLVMIRTranslationInterface
}
return success();
}

LogicalResult
convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {

llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
llvm::NamedMDNode *nvvmAnnotations =
moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");

if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
llvm::MDNode *gridConstantMetaData = nullptr;

// Check if a 'grid_constant' metadata node exists for the given function
for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
if (opnd->getNumOperands() == 3 &&
opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
opnd->getOperand(1) ==
llvm::MDString::get(llvmContext, "grid_constant")) {
gridConstantMetaData = opnd;
break;
}
}

// 'grid_constant' is a function-level meta data node with a list of
// integers, where each integer n denotes that the nth parameter has the
// grid_constant annotation (numbering from 1). This requires aggregating
// the indices of the individual parameters that have this attribute.
llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
if (gridConstantMetaData == nullptr) {
// Create a new 'grid_constant' metadata node
SmallVector<llvm::Metadata *> gridConstMetadata = {
llvm::ValueAsMetadata::getConstant(
llvm::ConstantInt::get(i32, argIdx + 1))};
llvm::Metadata *llvmMetadata[] = {
llvm::ValueAsMetadata::get(llvmFunc),
llvm::MDString::get(llvmContext, "grid_constant"),
llvm::MDNode::get(llvmContext, gridConstMetadata)};
llvm::MDNode *llvmMetadataNode =
llvm::MDNode::get(llvmContext, llvmMetadata);
nvvmAnnotations->addOperand(llvmMetadataNode);
} else {
// Append argIdx + 1 to the 'grid_constant' argument list
if (auto argList =
dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
llvm::TempMDTuple clonedArgList = argList->clone();
clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
llvm::ConstantInt::get(i32, argIdx + 1))));
gridConstantMetaData->replaceOperandWith(
2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
}
}
}
return success();
}
};
} // namespace

Expand Down
56 changes: 32 additions & 24 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->setMemoryEffects(newMemEffects);
}

llvm::AttrBuilder
ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
DictionaryAttr paramAttrs) {
llvm::AttrBuilder attrBuilder(llvmModule->getContext());

for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
Attribute attr = paramAttrs.get(mlirName);
// Skip attributes that are not present.
if (!attr)
continue;

// NOTE: C++17 does not support capturing structured bindings.
llvm::Attribute::AttrKind llvmKindCap = llvmKind;

llvm::TypeSwitch<Attribute>(attr)
.Case<TypeAttr>([&](auto typeAttr) {
attrBuilder.addTypeAttr(llvmKindCap,
convertType(typeAttr.getValue()));
})
.Case<IntegerAttr>([&](auto intAttr) {
attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
})
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
auto attrNameToKindMapping = getAttrNameToKindMapping();

for (auto namedAttr : paramAttrs) {
auto it = attrNameToKindMapping.find(namedAttr.getName());
if (it != attrNameToKindMapping.end()) {
llvm::Attribute::AttrKind llvmKind = it->second;

llvm::TypeSwitch<Attribute>(namedAttr.getValue())
.Case<TypeAttr>([&](auto typeAttr) {
attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
})
.Case<IntegerAttr>([&](auto intAttr) {
attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
})
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
} else if (namedAttr.getNameDialect()) {
if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
return failure();
}
}

return attrBuilder;
Expand Down Expand Up @@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert result attributes.
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
FailureOr<llvm::AttrBuilder> attrBuilder =
convertParameterAttrs(function, -1, resultAttrs);
if (failed(attrBuilder))
return failure();
llvmFunc->addRetAttrs(*attrBuilder);
}

// Convert argument attributes.
for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
llvmArg.addAttrs(attrBuilder);
FailureOr<llvm::AttrBuilder> attrBuilder =
convertParameterAttrs(function, argIdx, argAttrs);
if (failed(attrBuilder))
return failure();
llvmArg.addAttrs(*attrBuilder);
}
}

Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [

gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
}

// CHECK-LABEL : nvvm.grid_constant
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}

// -----

// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
llvm.return
}

// -----

// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}

// -----

// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
llvm.return
}
17 changes: 17 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
llvm.return
}

// -----
// CHECK: !nvvm.annotations =
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
// CHECK: !2 = !{i32 1}
// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}

// -----
// CHECK: !nvvm.annotations =
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
// CHECK: !2 = !{i32 1, i32 3}
// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}