-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Add ODS support for generating helpers for dialect (discardable) attributes #77024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-llvm Author: Mehdi Amini (joker-eph) ChangesWIP (missing docs) See #75118 for context Full diff: https://github.com/llvm/llvm-project/pull/77024.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..6abcfdf60e0fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
- /// Get the name of the attribute used to annotate external kernel
- /// functions.
- static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
- static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.flat_work_group_size");
- }
- static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
- }
-
/// The address space value that represents global memory.
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
/// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
}];
+ let discardableAttrs = (ins
+ "::mlir::UnitAttr":$kernel,
+ "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+ "::mlir::StringAttr":$flat_work_group_size,
+ "::mlir::IntegerAttr":$max_flat_work_group_size,
+ "::mlir::IntegerAttr":$waves_per_eu
+ );
+
let useDefaultAttributePrinterParser = 1;
}
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..16750dc7e4d320 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,8 @@ class Dialect {
// pattern or interfaces.
list<string> dependentDialects = [];
+ dag discardableAttrs = (ins);
+
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
#define MLIR_TABLEGEN_DIALECT_H_
#include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
#include <string>
#include <vector>
@@ -90,6 +92,10 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;
+ llvm::DagInit *getDiscardableAttributes() const;
+
+ const llvm::Record *getDef() const { return def; }
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..abd2733ba4fbd1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
-
+ auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto reqdWorkGroupSizeAttrHelper =
+ rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+ auto flatWorkGroupSizeAttrHelper =
+ rocdlDialect->getFlatWorkGroupSizeAttrHelper();
// Manually rewrite known block size attributes so the LLVMIR translation
// infrastructure can pick them up.
- m.walk([ctx](LLVM::LLVMFuncOp op) {
+ m.walk([&](LLVM::LLVMFuncOp op) {
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
- op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
- blockSizes);
+ reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
// conflicting metadata.
uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
}
StringAttr flatSizeAttr =
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
- op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
- flatSizeAttr);
+ flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
}
});
}
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
converter,
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
- StringAttr::get(&converter.getContext(),
- ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+ ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
if (Runtime::HIP == runtime) {
patterns.add<GPUPrintfOpToHIPLowering>(converter);
} else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// Kernel function attribute should be attached to functions.
- if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+ if (kernelAttrName.getName() == attr.getName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
- return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+ return op->emitError() << "'" << kernelAttrName.getName()
<< "' attribute attached to unexpected op";
}
}
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+ return def->getValueAsDag("discardableAttrs");
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 55a6285ec87eb4..6783ffcde6d531 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+ auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+ if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
// Override flat-work-group-size
// TODO: update clients to rocdl.flat_work_group_size instead,
// then remove this half of the branch
- if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+ if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+ attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
attrValueStream << "1," << value.getInt();
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
}
- if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+ if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
}
// Set reqd_work_group_size metadata
- if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+ if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
+ let discardableAttrs = (ins
+ "mlir::IntegerAttr":$discardable_attr_key,
+ "SimpleAAttr":$other_discardable_attr_key
+ );
let extraClassDeclaration = [{
void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..a0ebd9ce8f29e3 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
std::function<bool(const llvm::Record *)>>;
} // namespace
+static void populateDiscardableAttributes(
+ Dialect &dialect, llvm::DagInit *discardableAttrDag,
+ SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+ for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+ llvm::Init *arg = discardableAttrDag->getArg(i);
+
+ StringRef givenName = discardableAttrDag->getArgNameStr(i);
+ if (givenName.empty())
+ PrintFatalError(dialect.getDef()->getLoc(),
+ "discardable attributes must be named");
+ discardableAttributes.push_back(
+ {givenName.str(), arg->getAsUnquotedString()});
+ }
+}
+
/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
mlir::OperationName opName) override;
)";
+/// The code block for the discardable attribute helper.
+static const char *const discardableAttrHelperDecl = R"(
+ /// Helper to manage the discardable attribute `{1}`.
+ class {0}AttrHelper {{
+ mlir::StringAttr name;
+ public:
+ static constexpr llvm::StringLiteral getNameStr() {{
+ return "{4}.{1}";
+ }
+ constexpr mlir::StringAttr getName() {{
+ return name;
+ }
+
+ {0}AttrHelper(mlir::MLIRContext *ctx)
+ : name(mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+ {2} getAttr(::mlir::Operation *op) {{
+ return op->getAttrOfType<{2}>(getName());
+ }
+ void setAttr(::mlir::Operation *op, {2} val) {{
+ op->setAttr(getName(), val);
+ }
+ };
+ {0}AttrHelper get{0}AttrHelper() {
+ return {3}AttrName;
+ }
+ private:
+ {0}AttrHelper {3}AttrName;
+ public:
+)";
+
/// Generate the declaration for the given dialect class.
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
// Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/true);
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+ attrPair.first, attrPair.second, camelName,
+ dialect.getName());
+ }
+
if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;
@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// {1}: initialization code that is emitted in the ctor body before calling
/// initialize().
/// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
- : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+ : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+ {3}
+ {{
{1}
initialize();
}
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
)";
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+ const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
// Emit the constructor and destructor.
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+ std::string discardableAttributesInit;
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ llvm::raw_string_ostream os(discardableAttributesInit);
+ os << ", " << camelName << "AttrName(context)";
+ }
+
os << llvm::formatv(dialectConstructorStr, cppClassName,
- dependentDialectRegistrations, superClassName);
+ dependentDialectRegistrations, superClassName,
+ discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;
- emitDialectDef(*dialect, os);
+ emitDialectDef(*dialect, recordKeeper, os);
return false;
}
|
@llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) ChangesWIP (missing docs) See #75118 for context Full diff: https://github.com/llvm/llvm-project/pull/77024.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..6abcfdf60e0fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
- /// Get the name of the attribute used to annotate external kernel
- /// functions.
- static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
- static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.flat_work_group_size");
- }
- static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
- }
-
/// The address space value that represents global memory.
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
/// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
}];
+ let discardableAttrs = (ins
+ "::mlir::UnitAttr":$kernel,
+ "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+ "::mlir::StringAttr":$flat_work_group_size,
+ "::mlir::IntegerAttr":$max_flat_work_group_size,
+ "::mlir::IntegerAttr":$waves_per_eu
+ );
+
let useDefaultAttributePrinterParser = 1;
}
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..16750dc7e4d320 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,8 @@ class Dialect {
// pattern or interfaces.
list<string> dependentDialects = [];
+ dag discardableAttrs = (ins);
+
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
#define MLIR_TABLEGEN_DIALECT_H_
#include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
#include <string>
#include <vector>
@@ -90,6 +92,10 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;
+ llvm::DagInit *getDiscardableAttributes() const;
+
+ const llvm::Record *getDef() const { return def; }
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..abd2733ba4fbd1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
-
+ auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto reqdWorkGroupSizeAttrHelper =
+ rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+ auto flatWorkGroupSizeAttrHelper =
+ rocdlDialect->getFlatWorkGroupSizeAttrHelper();
// Manually rewrite known block size attributes so the LLVMIR translation
// infrastructure can pick them up.
- m.walk([ctx](LLVM::LLVMFuncOp op) {
+ m.walk([&](LLVM::LLVMFuncOp op) {
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
- op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
- blockSizes);
+ reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
// conflicting metadata.
uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
}
StringAttr flatSizeAttr =
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
- op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
- flatSizeAttr);
+ flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
}
});
}
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
converter,
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
- StringAttr::get(&converter.getContext(),
- ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+ ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
if (Runtime::HIP == runtime) {
patterns.add<GPUPrintfOpToHIPLowering>(converter);
} else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// Kernel function attribute should be attached to functions.
- if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+ if (kernelAttrName.getName() == attr.getName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
- return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+ return op->emitError() << "'" << kernelAttrName.getName()
<< "' attribute attached to unexpected op";
}
}
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+ return def->getValueAsDag("discardableAttrs");
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 55a6285ec87eb4..6783ffcde6d531 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+ auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+ if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
// Override flat-work-group-size
// TODO: update clients to rocdl.flat_work_group_size instead,
// then remove this half of the branch
- if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+ if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+ attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
attrValueStream << "1," << value.getInt();
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
}
- if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+ if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
}
// Set reqd_work_group_size metadata
- if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+ if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
+ let discardableAttrs = (ins
+ "mlir::IntegerAttr":$discardable_attr_key,
+ "SimpleAAttr":$other_discardable_attr_key
+ );
let extraClassDeclaration = [{
void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..a0ebd9ce8f29e3 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
std::function<bool(const llvm::Record *)>>;
} // namespace
+static void populateDiscardableAttributes(
+ Dialect &dialect, llvm::DagInit *discardableAttrDag,
+ SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+ for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+ llvm::Init *arg = discardableAttrDag->getArg(i);
+
+ StringRef givenName = discardableAttrDag->getArgNameStr(i);
+ if (givenName.empty())
+ PrintFatalError(dialect.getDef()->getLoc(),
+ "discardable attributes must be named");
+ discardableAttributes.push_back(
+ {givenName.str(), arg->getAsUnquotedString()});
+ }
+}
+
/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
mlir::OperationName opName) override;
)";
+/// The code block for the discardable attribute helper.
+static const char *const discardableAttrHelperDecl = R"(
+ /// Helper to manage the discardable attribute `{1}`.
+ class {0}AttrHelper {{
+ mlir::StringAttr name;
+ public:
+ static constexpr llvm::StringLiteral getNameStr() {{
+ return "{4}.{1}";
+ }
+ constexpr mlir::StringAttr getName() {{
+ return name;
+ }
+
+ {0}AttrHelper(mlir::MLIRContext *ctx)
+ : name(mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+ {2} getAttr(::mlir::Operation *op) {{
+ return op->getAttrOfType<{2}>(getName());
+ }
+ void setAttr(::mlir::Operation *op, {2} val) {{
+ op->setAttr(getName(), val);
+ }
+ };
+ {0}AttrHelper get{0}AttrHelper() {
+ return {3}AttrName;
+ }
+ private:
+ {0}AttrHelper {3}AttrName;
+ public:
+)";
+
/// Generate the declaration for the given dialect class.
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
// Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/true);
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+ attrPair.first, attrPair.second, camelName,
+ dialect.getName());
+ }
+
if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;
@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// {1}: initialization code that is emitted in the ctor body before calling
/// initialize().
/// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
- : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+ : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+ {3}
+ {{
{1}
initialize();
}
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
)";
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+ const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
// Emit the constructor and destructor.
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+ std::string discardableAttributesInit;
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ llvm::raw_string_ostream os(discardableAttributesInit);
+ os << ", " << camelName << "AttrName(context)";
+ }
+
os << llvm::formatv(dialectConstructorStr, cppClassName,
- dependentDialectRegistrations, superClassName);
+ dependentDialectRegistrations, superClassName,
+ discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;
- emitDialectDef(*dialect, os);
+ emitDialectDef(*dialect, recordKeeper, os);
return false;
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's it @joker-eph, Thanks! Appreciate it!
Just a minor ask.
@@ -34,6 +34,8 @@ class Dialect { | |||
// pattern or interfaces. | |||
list<string> dependentDialects = []; | |||
|
|||
dag discardableAttrs = (ins); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is my mental model correct that these are all effectively optional attributes for which there are helpers generated? But they are different than the others as not verified, most likely empty? (e.g., I'm not sure if doing this vs making these optional attributes when one would select which ...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if I followed correctly what you're asking about, but these are helpers to manage discardable attributes. The idea is that for example the ROCDL dialect involves a discardable attribute with the key rocdl.reqd_work_group_size
, and the associated Attribute must be a DenseI32ArrayAttr
.
This feature will generate a helper for managing this discardable attribute on operations. The operations don't know anything about the attribute, it's a property of the dialect itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense to me. These being dialect known discardable rather than op known optional.
Looks like a spurious windows failure or just needs updated to TIP? |
Sorry, I need to address the comments, I'm just temporarily out of a workstation! Feel free to ping next week |
No worries. I am suffering the same fate.. |
…e) attributes This is a new ODS feature that allows dialects to define a list of key/value pair representing an attribute type and a name. This will generate helper classes on the dialect to be able to manage discardable attributes on operations in a type safe way. For example the `test` dialect can define: ``` let discardableAttrs = (ins "mlir::IntegerAttr":$discardable_attr_key, ); ``` And the following will be generated in the TestDialect class: ``` /// Helper to manage the discardable attribute `discardable_attr_key`. class DiscardableAttrKeyAttrHelper { ::mlir::StringAttr name; public: static constexpr ::llvm::StringLiteral getNameStr() { return "test.discardable_attr_key"; } constexpr ::mlir::StringAttr getName() { return name; } DiscardableAttrKeyAttrHelper(::mlir::MLIRContext *ctx) : name(::mlir::StringAttr::get(ctx, getNameStr())) {} mlir::IntegerAttr getAttr(::mlir::Operation *op) { return op->getAttrOfType<mlir::IntegerAttr>(name); } void setAttr(::mlir::Operation *op, mlir::IntegerAttr val) { op->setAttr(name, val); } bool isAttrPresent(::mlir::Operation *op) { return op->hasAttrOfType<mlir::IntegerAttr>(name); } void removeAttr(::mlir::Operation *op) { assert(op->hasAttrOfType<mlir::IntegerAttr>(name)); op->removeAttr(name); } }; DiscardableAttrKeyAttrHelper getDiscardableAttrKeyAttrHelper() { return discardableAttrKeyAttrName; } ``` User code having an instance of the TestDialect can then manipulate this attribute on operation using: ``` auto helper = testDialect.getDiscardableAttrKeyAttrHelper(); helper.setAttr(op, value); helper.isAttrPresent(op); ... ```
f64d744
to
40a814a
Compare
This is a new ODS feature that allows dialects to define a list of
key/value pair representing an attribute type and a name.
This will generate helper classes on the dialect to be able to
manage discardable attributes on operations in a type safe way.
For example the
test
dialect can define:And the following will be generated in the TestDialect class:
User code having an instance of the TestDialect can then manipulate this
attribute on operation using: