-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][nvvm]Add support for grid_constant attribute on LLVM function arguments #78228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-mlir Author: Rishi Surendran (rishisurendran) ChangesAdd support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute. This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters Full diff: https://github.com/llvm/llvm-project/pull/78228.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..1fc5ee2c32bd492 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
/// Get the name of the attribute used to annotate max number of
/// registers that can be allocated per thread.
static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+ /// Get the name of the attribute used to annotate kernel arguments that
+ /// are grid constants.
+ static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+ /// Verify an attribute from this dialect on the argument at 'argIndex' for
+ /// the region at 'regionIndex' on the given operation. Returns failure if
+ /// the verification failed, success otherwise. This hook may optionally be
+ /// invoked from any operation containing a region.
+ LogicalResult verifyRegionArgAttribute(Operation *,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute) override;
}];
let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 19991a6f89d80fa..55358ebc6e86efc 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -13,6 +13,7 @@
#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
namespace mlir {
namespace LLVM {
class ModuleTranslation;
+class LLVMFuncOp;
} // namespace LLVM
/// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}
+
+ /// Hook for derived dialect interface to translate or act on a derived
+ /// dialect attribute that appears on a function parameter. This gets called
+ /// after the function operation has been translated.
+ virtual LogicalResult
+ convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+ NamedAttribute attr,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ return success();
+ }
};
/// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
}
return success();
}
+
+ /// Acts on the given function operation using the interface implemented by
+ /// the dialect of one of the function parameter attributes.
+ virtual LogicalResult
+ convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ if (const LLVMTranslationDialectInterface *iface =
+ getInterfaceFor(attribute.getNameDialect())) {
+ return iface->convertParameterAttr(function, argIdx, attribute,
+ moduleTranslation);
+ }
+ return success();
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d6b03aca28d24d5..f0012bf875511ee 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -326,8 +326,8 @@ class ModuleTranslation {
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);
- /// Translates parameter attributes and adds them to the returned AttrBuilder.
- llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+ FailureOr<llvm::AttrBuilder>
+ convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
/// Original and translated module.
Operation *mlirModule;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc02..dc7816318131e41 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
return success();
}
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute argAttr) {
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+
+ bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+ auto attrName = argAttr.getName();
+ if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ if (!isKernel)
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute must be present only on kernel arguments.";
+ if (!llvm::isa<UnitAttr>(argAttr.getValue()))
+ return op->emitError()
+ << "'" << attrName << "' must be a unit attribute.";
+ if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute requires the argument to also have attribute '"
+ << LLVM::LLVMDialect::getByValAttrName() << "'.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 7f81777886f56eb..55a364856bd6f99 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
return kindNamePairs;
}
+static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+getAttrNameToKindMapping() {
+ static auto attrNameToKindMapping = []() {
+ static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+ nameKindMap;
+ for (auto kindNamePair : getAttrKindToNameMapping()) {
+ nameKindMap.insert({kindNamePair.second, kindNamePair.first});
+ }
+ return nameKindMap;
+ }();
+ return attrNameToKindMapping;
+}
+
} // namespace detail
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 45eb8402a7344f4..5e1712527d70151 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
}
return success();
}
+
+ LogicalResult
+ convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+
+ llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+ llvm::Function *llvmFunc =
+ moduleTranslation.lookupFunction(funcOp.getName());
+ auto nvvmAnnotations =
+ moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
+
+ if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ llvm::MDNode *gridConstantMetaData = nullptr;
+
+ // Check if a 'grid_constant' metadata node exists for the given function
+ for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
+ auto opnd = nvvmAnnotations->getOperand(i);
+ if (opnd->getNumOperands() == 3 &&
+ opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
+ opnd->getOperand(1) ==
+ llvm::MDString::get(llvmContext, "grid_constant")) {
+ gridConstantMetaData = opnd;
+ break;
+ }
+ }
+
+ // 'grid_constant' is a function-level meta data node with a list of
+ // integers, where each integer n denotes that the nth parameter has the
+ // grid_constant annotation (numbering from 1). This requires aggregating
+ // the indices of the individual parameters that have this attribute.
+ llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
+ if (gridConstantMetaData == nullptr) {
+ // Create a new 'grid_constant' metadata node
+ SmallVector<llvm::Metadata *> gridConstMetadata = {
+ llvm::ValueAsMetadata::getConstant(
+ llvm::ConstantInt::get(i32, argIdx + 1))};
+ llvm::Metadata *llvmMetadata[] = {
+ llvm::ValueAsMetadata::get(llvmFunc),
+ llvm::MDString::get(llvmContext, "grid_constant"),
+ llvm::MDNode::get(llvmContext, gridConstMetadata)};
+ llvm::MDNode *llvmMetadataNode =
+ llvm::MDNode::get(llvmContext, llvmMetadata);
+ nvvmAnnotations->addOperand(llvmMetadataNode);
+ } else {
+ // Append argIdx + 1 to the 'grid_constant' argument list
+ if (auto argList =
+ dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
+ auto clonedArgList = argList->clone();
+ clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
+ llvm::ConstantInt::get(i32, argIdx + 1))));
+ gridConstantMetaData->replaceOperandWith(
+ 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
+ }
+ }
+ }
+ return success();
+ }
};
} // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2763a0fdd62aba1..574dbfa177b9bb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->setMemoryEffects(newMemEffects);
}
-llvm::AttrBuilder
-ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
+ DictionaryAttr paramAttrs) {
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
-
- for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- Attribute attr = paramAttrs.get(mlirName);
- // Skip attributes that are not present.
- if (!attr)
- continue;
-
- // NOTE: C++17 does not support capturing structured bindings.
- llvm::Attribute::AttrKind llvmKindCap = llvmKind;
-
- llvm::TypeSwitch<Attribute>(attr)
- .Case<TypeAttr>([&](auto typeAttr) {
- attrBuilder.addTypeAttr(llvmKindCap,
- convertType(typeAttr.getValue()));
- })
- .Case<IntegerAttr>([&](auto intAttr) {
- attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
- })
- .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
+ auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+ for (auto namedAttr : paramAttrs) {
+ auto it = attrNameToKindMapping.find(namedAttr.getName());
+ if (it != attrNameToKindMapping.end()) {
+ llvm::Attribute::AttrKind llvmKind = it->second;
+
+ llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+ .Case<TypeAttr>([&](auto typeAttr) {
+ attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
+ })
+ .Case<IntegerAttr>([&](auto intAttr) {
+ attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+ })
+ .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
+ } else if (namedAttr.getNameDialect()) {
+ if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
+ return failure();
+ }
}
return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert result attributes.
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
- llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(function, -1, resultAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ llvmFunc->addRetAttrs(*attrBuilder);
}
// Convert argument attributes.
for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
- llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
- llvmArg.addAttrs(attrBuilder);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(function, argIdx, argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ llvmArg.addAttrs(*attrBuilder);
}
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index ce483ddab22a0ee..0369f45ca6a0156 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
}
+
+// CHECK-LABEL : nvvm.grid_constant
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..6dc47d08fc5c812 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
llvm.return
}
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1, i32 3}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
|
@llvm/pr-subscribers-mlir-llvm Author: Rishi Surendran (rishisurendran) ChangesAdd support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute. This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters Full diff: https://github.com/llvm/llvm-project/pull/78228.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..1fc5ee2c32bd49 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
/// Get the name of the attribute used to annotate max number of
/// registers that can be allocated per thread.
static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+ /// Get the name of the attribute used to annotate kernel arguments that
+ /// are grid constants.
+ static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+ /// Verify an attribute from this dialect on the argument at 'argIndex' for
+ /// the region at 'regionIndex' on the given operation. Returns failure if
+ /// the verification failed, success otherwise. This hook may optionally be
+ /// invoked from any operation containing a region.
+ LogicalResult verifyRegionArgAttribute(Operation *,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute) override;
}];
let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 19991a6f89d80f..55358ebc6e86ef 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -13,6 +13,7 @@
#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
namespace mlir {
namespace LLVM {
class ModuleTranslation;
+class LLVMFuncOp;
} // namespace LLVM
/// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}
+
+ /// Hook for derived dialect interface to translate or act on a derived
+ /// dialect attribute that appears on a function parameter. This gets called
+ /// after the function operation has been translated.
+ virtual LogicalResult
+ convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+ NamedAttribute attr,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ return success();
+ }
};
/// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
}
return success();
}
+
+ /// Acts on the given function operation using the interface implemented by
+ /// the dialect of one of the function parameter attributes.
+ virtual LogicalResult
+ convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ if (const LLVMTranslationDialectInterface *iface =
+ getInterfaceFor(attribute.getNameDialect())) {
+ return iface->convertParameterAttr(function, argIdx, attribute,
+ moduleTranslation);
+ }
+ return success();
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d6b03aca28d24d..f0012bf875511e 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -326,8 +326,8 @@ class ModuleTranslation {
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);
- /// Translates parameter attributes and adds them to the returned AttrBuilder.
- llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+ FailureOr<llvm::AttrBuilder>
+ convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
/// Original and translated module.
Operation *mlirModule;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..dc7816318131e4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
return success();
}
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute argAttr) {
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+
+ bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+ auto attrName = argAttr.getName();
+ if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ if (!isKernel)
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute must be present only on kernel arguments.";
+ if (!llvm::isa<UnitAttr>(argAttr.getValue()))
+ return op->emitError()
+ << "'" << attrName << "' must be a unit attribute.";
+ if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute requires the argument to also have attribute '"
+ << LLVM::LLVMDialect::getByValAttrName() << "'.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 7f81777886f56e..55a364856bd6f9 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
return kindNamePairs;
}
+static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+getAttrNameToKindMapping() {
+ static auto attrNameToKindMapping = []() {
+ static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+ nameKindMap;
+ for (auto kindNamePair : getAttrKindToNameMapping()) {
+ nameKindMap.insert({kindNamePair.second, kindNamePair.first});
+ }
+ return nameKindMap;
+ }();
+ return attrNameToKindMapping;
+}
+
} // namespace detail
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 45eb8402a7344f..5e1712527d7015 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
}
return success();
}
+
+ LogicalResult
+ convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+
+ llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+ llvm::Function *llvmFunc =
+ moduleTranslation.lookupFunction(funcOp.getName());
+ auto nvvmAnnotations =
+ moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
+
+ if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ llvm::MDNode *gridConstantMetaData = nullptr;
+
+ // Check if a 'grid_constant' metadata node exists for the given function
+ for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
+ auto opnd = nvvmAnnotations->getOperand(i);
+ if (opnd->getNumOperands() == 3 &&
+ opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
+ opnd->getOperand(1) ==
+ llvm::MDString::get(llvmContext, "grid_constant")) {
+ gridConstantMetaData = opnd;
+ break;
+ }
+ }
+
+ // 'grid_constant' is a function-level meta data node with a list of
+ // integers, where each integer n denotes that the nth parameter has the
+ // grid_constant annotation (numbering from 1). This requires aggregating
+ // the indices of the individual parameters that have this attribute.
+ llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
+ if (gridConstantMetaData == nullptr) {
+ // Create a new 'grid_constant' metadata node
+ SmallVector<llvm::Metadata *> gridConstMetadata = {
+ llvm::ValueAsMetadata::getConstant(
+ llvm::ConstantInt::get(i32, argIdx + 1))};
+ llvm::Metadata *llvmMetadata[] = {
+ llvm::ValueAsMetadata::get(llvmFunc),
+ llvm::MDString::get(llvmContext, "grid_constant"),
+ llvm::MDNode::get(llvmContext, gridConstMetadata)};
+ llvm::MDNode *llvmMetadataNode =
+ llvm::MDNode::get(llvmContext, llvmMetadata);
+ nvvmAnnotations->addOperand(llvmMetadataNode);
+ } else {
+ // Append argIdx + 1 to the 'grid_constant' argument list
+ if (auto argList =
+ dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
+ auto clonedArgList = argList->clone();
+ clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
+ llvm::ConstantInt::get(i32, argIdx + 1))));
+ gridConstantMetaData->replaceOperandWith(
+ 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
+ }
+ }
+ }
+ return success();
+ }
};
} // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2763a0fdd62aba..574dbfa177b9bb 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->setMemoryEffects(newMemEffects);
}
-llvm::AttrBuilder
-ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
+ DictionaryAttr paramAttrs) {
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
-
- for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- Attribute attr = paramAttrs.get(mlirName);
- // Skip attributes that are not present.
- if (!attr)
- continue;
-
- // NOTE: C++17 does not support capturing structured bindings.
- llvm::Attribute::AttrKind llvmKindCap = llvmKind;
-
- llvm::TypeSwitch<Attribute>(attr)
- .Case<TypeAttr>([&](auto typeAttr) {
- attrBuilder.addTypeAttr(llvmKindCap,
- convertType(typeAttr.getValue()));
- })
- .Case<IntegerAttr>([&](auto intAttr) {
- attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
- })
- .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
+ auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+ for (auto namedAttr : paramAttrs) {
+ auto it = attrNameToKindMapping.find(namedAttr.getName());
+ if (it != attrNameToKindMapping.end()) {
+ llvm::Attribute::AttrKind llvmKind = it->second;
+
+ llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+ .Case<TypeAttr>([&](auto typeAttr) {
+ attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
+ })
+ .Case<IntegerAttr>([&](auto intAttr) {
+ attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+ })
+ .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
+ } else if (namedAttr.getNameDialect()) {
+ if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
+ return failure();
+ }
}
return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert result attributes.
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
- llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(function, -1, resultAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ llvmFunc->addRetAttrs(*attrBuilder);
}
// Convert argument attributes.
for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
- llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
- llvmArg.addAttrs(attrBuilder);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(function, argIdx, argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ llvmArg.addAttrs(*attrBuilder);
}
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index ce483ddab22a0e..0369f45ca6a015 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
}
+
+// CHECK-LABEL : nvvm.grid_constant
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f6..6dc47d08fc5c81 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
llvm.return
}
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1, i32 3}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
|
e947d27
to
28cdab8
Compare
return iface->convertParameterAttr(function, argIdx, attribute, | ||
moduleTranslation); | ||
} | ||
return success(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure the default here should be success. This would be dropping any attribute from a dialect without interface on the floor without warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will modify it to return failure. I followed what amendOperation
was doing for op attributes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returning failure causes several test failures. There are dialect attributes like 'fir.bindc_name' which doesn't require any handling here. I updated it to emit a warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(unresolving, @ftynse should take another look here and acknowledge the solution explicitly)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, the analogy is fair enough.
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
/// dialect attribute that appears on a function parameter. This gets called | ||
/// after the function operation has been translated. | ||
virtual LogicalResult | ||
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of LLVM uses unsigned
for indexes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are you referring to?
I strongly believe in using int
in absence of bit/mask operation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
$ git grep " int " | wc -l
1161
$ git grep " unsigned " | wc -l
3449
also most of the former is in tests or tools.
That being said, I'm not a proponent of using unsigned
(I'd rather use int32/64_t
throughout) , but I am a proponent of consistency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do you suggest to fix this: I see it as a bug and I am consistently using signed int
everywhere!
Should I sent patched updating every file for consistency or should we make sure new code use safer arithmetic patterns?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to have some clarity on what's the convention... Has anything came out after https://discourse.llvm.org/t/rfc-coding-standards-prefer-int-for-regular-arithmetic-use-unsigned-only-for-bitmask-and-when-you-intend-to-rely-on-wrapping-behavior/52191 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does not seems like we'll be able to make anything a LLVM-wide policy here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MLIR side we could. I'm rather pro not making it worse though (e.g., I'd consider consistency when it's towards direction we'd want to go rather than where we ended up in undesirable position and keeping going down that route)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution. I am not clear how it is used grid_constant
annotation? Is there PR in LLVM?
|
||
/// Get the name of the attribute used to annotate kernel arguments that | ||
/// are grid constants. | ||
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an attribute for a kernel parameter, while other attributes are for the kernel itself. We might need to split them later
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
There are no changes in LLVM. I have updated the description. The generated metadata will be handled by libNVVM. |
Thanks for the explanation. It is useful to have this is supported by libNVVM, and frankly I am okay with this PR. But it would be nice to agree the attribute with open-source llvm folks. |
This is something that LLVM will support eventually as well: it is part of the NVVM specification and it is needed to support the grid_constant CUDA C++ feature. |
We're currently focusing on adding support for MLIR to cover the NVVM IR specification, of course LLVM support should also be added. We haven’t settled on a roadmap for this on our side just yet though, maybe others will pick this up in the meantime. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good to me even though open-source llvm does not have this metadata. But I would wait @ftynse to review once again
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for me.
return iface->convertParameterAttr(function, argIdx, attribute, | ||
moduleTranslation); | ||
} | ||
return success(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, the analogy is fair enough.
Register the LLVM IR translation interface for FIR to avoid warnings about "Unhandled parameter attribute" after llvm#78228.
Register the LLVM IR translation interface for FIR to avoid warnings about "Unhandled parameter attribute" after #78228.
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607348584
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607348584
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607348584
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
… is set (#3147) This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
… is set (#3147) This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607348584
… is set (triton-lang#3147) This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
Add support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute.
Generate LLVM metadata for functions with nvvm.grid_constant arguments. The metadata node is a list of integers, where each integer n denotes that the nth parameter has the
grid_constant annotation (numbering from 1). The generated metadata node will be handled by NVVM compiler. See https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for documentation on grid_constant property.
This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters